From 714a061ba39d7c7528c43b1bf4901158a93d6d7d Mon Sep 17 00:00:00 2001 From: MkQtS <81752398+MkQtS@users.noreply.github.com> Date: Wed, 18 Mar 2026 15:58:47 +0800 Subject: [PATCH] main.go: improve codes (#3366) * main.go: improve codes * main.go: add parseInclusion - separate from parseEntry - not allow affiliation for inclusion --- main.go | 197 +++++++++++++++++++++++++++++++------------------------- 1 file changed, 111 insertions(+), 86 deletions(-) diff --git a/main.go b/main.go index 7baba279..66f61ec6 100644 --- a/main.go +++ b/main.go @@ -47,7 +47,7 @@ type Processor struct { cirIncMap map[string]bool } -func makeProtoList(listName string, entries []*Entry) (*router.GeoSite, error) { +func makeProtoList(listName string, entries []*Entry) *router.GeoSite { site := &router.GeoSite{ CountryCode: listName, Domain: make([]*router.Domain, 0, len(entries)), @@ -73,7 +73,7 @@ func makeProtoList(listName string, entries []*Entry) (*router.GeoSite, error) { } site.Domain = append(site.Domain, pdomain) } - return site, nil + return site } func writePlainList(listname string, entries []*Entry) error { @@ -89,46 +89,28 @@ func writePlainList(listname string, entries []*Entry) error { return w.Flush() } -func parseEntry(line string) (*Entry, []string, error) { - entry := new(Entry) - parts := strings.Fields(line) +func parseEntry(typ, rule string) (*Entry, []string, error) { + entry := &Entry{Type: typ} + parts := strings.Fields(rule) if len(parts) == 0 { - return entry, nil, fmt.Errorf("empty line") + return entry, nil, fmt.Errorf("empty domain rule") } - - // Parse type and value - typ, val, isTypeSpecified := strings.Cut(parts[0], ":") - typ = strings.ToLower(typ) - if !isTypeSpecified { // Default RuleType - if !validateDomainChars(typ) { - return entry, nil, fmt.Errorf("invalid domain: %q", typ) + // Parse value + switch entry.Type { + case dlc.RuleTypeRegexp: + if _, err := regexp.Compile(parts[0]); err != nil { + return entry, nil, fmt.Errorf("invalid regexp %q: 
%w", parts[0], err) } - entry.Type = dlc.RuleTypeDomain - entry.Value = typ - } else { - switch typ { - case dlc.RuleTypeRegexp: - if _, err := regexp.Compile(val); err != nil { - return entry, nil, fmt.Errorf("invalid regexp %q: %w", val, err) - } - entry.Type = dlc.RuleTypeRegexp - entry.Value = val - case dlc.RuleTypeInclude: - entry.Type = dlc.RuleTypeInclude - entry.Value = strings.ToUpper(val) - if !validateSiteName(entry.Value) { - return entry, nil, fmt.Errorf("invalid included list name: %q", entry.Value) - } - case dlc.RuleTypeDomain, dlc.RuleTypeFullDomain, dlc.RuleTypeKeyword: - entry.Type = typ - entry.Value = strings.ToLower(val) - if !validateDomainChars(entry.Value) { - return entry, nil, fmt.Errorf("invalid domain: %q", entry.Value) - } - default: - return entry, nil, fmt.Errorf("invalid type: %q", typ) + entry.Value = parts[0] + case dlc.RuleTypeDomain, dlc.RuleTypeFullDomain, dlc.RuleTypeKeyword: + entry.Value = strings.ToLower(parts[0]) + if !validateDomainChars(entry.Value) { + return entry, nil, fmt.Errorf("invalid domain: %q", entry.Value) } + default: + return entry, nil, fmt.Errorf("unknown rule type: %q", entry.Type) } + plen := len(entry.Type) + len(entry.Value) + 1 // Parse attributes and affiliations var affs []string @@ -140,6 +122,7 @@ func parseEntry(line string) (*Entry, []string, error) { return entry, affs, fmt.Errorf("invalid attribute: %q", attr) } entry.Attrs = append(entry.Attrs, attr) + plen += 2 + len(attr) case '&': aff := strings.ToUpper(part[1:]) if !validateSiteName(aff) { @@ -147,33 +130,70 @@ func parseEntry(line string) (*Entry, []string, error) { } affs = append(affs, aff) default: - return entry, affs, fmt.Errorf("invalid attribute/affiliation: %q", part) + return entry, affs, fmt.Errorf("unknown field: %q", part) } } - if entry.Type != dlc.RuleTypeInclude { - slices.Sort(entry.Attrs) // Sort attributes - // Formated plain entry: type:domain.tld:@attr1,@attr2 - var plain strings.Builder - plain.Grow(len(entry.Type) 
+ len(entry.Value) + 10) - plain.WriteString(entry.Type) - plain.WriteByte(':') - plain.WriteString(entry.Value) - for i, attr := range entry.Attrs { - if i == 0 { - plain.WriteByte(':') - } else { - plain.WriteByte(',') - } - plain.WriteByte('@') - plain.WriteString(attr) + slices.Sort(entry.Attrs) // Sort attributes + // Formated plain entry: type:domain.tld:@attr1,@attr2 + var plain strings.Builder + plain.Grow(plen) + plain.WriteString(entry.Type) + plain.WriteByte(':') + plain.WriteString(entry.Value) + for i, attr := range entry.Attrs { + if i == 0 { + plain.WriteByte(':') + } else { + plain.WriteByte(',') } - entry.Plain = plain.String() + plain.WriteByte('@') + plain.WriteString(attr) } + entry.Plain = plain.String() return entry, affs, nil } +func parseInclusion(rule string) (*Inclusion, error) { + parts := strings.Fields(rule) + if len(parts) == 0 { + return nil, fmt.Errorf("empty inclusion") + } + inc := &Inclusion{Source: strings.ToUpper(parts[0])} + if !validateSiteName(inc.Source) { + return inc, fmt.Errorf("invalid included list name: %q", inc.Source) + } + + // Parse attributes + for _, part := range parts[1:] { + switch part[0] { + case '@': + attr := strings.ToLower(part[1:]) + if attr[0] == '-' { + battr := attr[1:] + if !validateAttrChars(battr) { + return inc, fmt.Errorf("invalid ban attribute: %q", battr) + } + inc.BanAttrs = append(inc.BanAttrs, battr) + } else { + if !validateAttrChars(attr) { + return inc, fmt.Errorf("invalid must attribute: %q", attr) + } + inc.MustAttrs = append(inc.MustAttrs, attr) + } + case '&': + return inc, fmt.Errorf("affiliation is not allowed for inclusion") + default: + return inc, fmt.Errorf("unknown field: %q", part) + } + } + return inc, nil +} + func validateDomainChars(domain string) bool { + if domain == "" { + return false + } for i := range domain { c := domain[i] if (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '.' 
|| c == '-' { @@ -185,9 +205,12 @@ func validateDomainChars(domain string) bool { } func validateAttrChars(attr string) bool { + if attr == "" { + return false + } for i := range attr { c := attr[i] - if (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '!' || c == '-' { + if (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '!' { continue } return false @@ -196,6 +219,9 @@ func validateAttrChars(attr string) bool { } func validateSiteName(name string) bool { + if name == "" { + return false + } for i := range name { c := name[i] if (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '!' || c == '-' { @@ -232,26 +258,23 @@ func (p *Processor) loadData(listName string, path string) error { if line == "" { continue } - entry, affs, err := parseEntry(line) - if err != nil { - return fmt.Errorf("error in %q at line %d: %w", path, lineIdx, err) + typ, rule, isTypeSpecified := strings.Cut(line, ":") + if !isTypeSpecified { // Default RuleType + typ, rule = dlc.RuleTypeDomain, typ + } else { + typ = strings.ToLower(typ) } - - if entry.Type == dlc.RuleTypeInclude { - inc := &Inclusion{Source: entry.Value} - for _, attr := range entry.Attrs { - if attr[0] == '-' { - inc.BanAttrs = append(inc.BanAttrs, attr[1:]) - } else { - inc.MustAttrs = append(inc.MustAttrs, attr) - } - } - for _, aff := range affs { - apl := p.getOrCreateParsedList(aff) - apl.Inclusions = append(apl.Inclusions, inc) + if typ == dlc.RuleTypeInclude { + inc, err := parseInclusion(rule) + if err != nil { + return fmt.Errorf("error in %q at line %d: %w", path, lineIdx, err) } pl.Inclusions = append(pl.Inclusions, inc) } else { + entry, affs, err := parseEntry(typ, rule) + if err != nil { + return fmt.Errorf("error in %q at line %d: %w", path, lineIdx, err) + } for _, aff := range affs { apl := p.getOrCreateParsedList(aff) apl.Entries = append(apl.Entries, entry) @@ -259,7 +282,7 @@ func (p *Processor) loadData(listName string, path string) error { pl.Entries = append(pl.Entries, entry) } 
} - return nil + return scanner.Err() } func isMatchAttrFilters(entry *Entry, incFilter *Inclusion) bool { @@ -360,6 +383,9 @@ func (p *Processor) resolveList(plname string) error { } } } + if len(roughMap) == 0 { + return fmt.Errorf("empty list") + } p.finalMap[plname] = polishList(roughMap) return nil } @@ -387,13 +413,15 @@ func run() error { return fmt.Errorf("failed to loadData: %w", err) } // Generate finalMap - processor.finalMap = make(map[string][]*Entry, len(processor.plMap)) + sitesCount := len(processor.plMap) + processor.finalMap = make(map[string][]*Entry, sitesCount) processor.cirIncMap = make(map[string]bool) for plname := range processor.plMap { if err := processor.resolveList(plname); err != nil { return fmt.Errorf("failed to resolveList %q: %w", plname, err) } } + processor.plMap = nil // Make sure output directory exists if err := os.MkdirAll(*outputDir, 0755); err != nil { @@ -403,27 +431,24 @@ func run() error { for rawEpList := range strings.SplitSeq(*exportLists, ",") { if epList := strings.TrimSpace(rawEpList); epList != "" { entries, exist := processor.finalMap[strings.ToUpper(epList)] - if !exist || len(entries) == 0 { - fmt.Printf("list %q does not exist or is empty\n", epList) + if !exist { + fmt.Printf("[Warn] list %q does not exist\n", epList) continue } if err := writePlainList(epList, entries); err != nil { - fmt.Printf("failed to write list %q: %v\n", epList, err) + fmt.Printf("[Error] failed to write list %q: %v\n", epList, err) continue } - fmt.Printf("list %q has been generated successfully.\n", epList) + fmt.Printf("list %q has been generated successfully\n", epList) } } // Generate dat file - protoList := new(router.GeoSiteList) + protoList := &router.GeoSiteList{Entry: make([]*router.GeoSite, 0, sitesCount)} for siteName, siteEntries := range processor.finalMap { - site, err := makeProtoList(siteName, siteEntries) - if err != nil { - return fmt.Errorf("failed to makeProtoList %q: %w", siteName, err) - } - protoList.Entry = 
append(protoList.Entry, site) + protoList.Entry = append(protoList.Entry, makeProtoList(siteName, siteEntries)) } + processor = nil // Sort protoList so the marshaled list is reproducible slices.SortFunc(protoList.Entry, func(a, b *router.GeoSite) int { return strings.Compare(a.CountryCode, b.CountryCode) @@ -436,14 +461,14 @@ func run() error { if err := os.WriteFile(filepath.Join(*outputDir, *outputName), protoBytes, 0644); err != nil { return fmt.Errorf("failed to write output: %w", err) } - fmt.Printf("%q has been generated successfully.\n", *outputName) + fmt.Printf("%q has been generated successfully\n", *outputName) return nil } func main() { flag.Parse() if err := run(); err != nil { - fmt.Printf("Fatal error: %v\n", err) + fmt.Printf("[Fatal] critical error: %v\n", err) os.Exit(1) } }