diff --git a/main.go b/main.go index 935084fa..e5fcfb4c 100644 --- a/main.go +++ b/main.go @@ -36,7 +36,10 @@ var ( ) var ( - refMap = make(map[string]*List) + refMap = make(map[string]*List) + plMap = make(map[string]*ParsedList) + finalMap = make(map[string]*List) + cirIncMap = make(map[string]bool) // Used for circular inclusion detection ) type Entry struct { @@ -45,18 +48,24 @@ type Entry struct { Attrs []string } +type Inclusion struct { + Source string + MustAttrs []string + BannedAttrs []string +} + type List struct { Name string Entry []Entry } type ParsedList struct { - Name string - Inclusion map[string]bool - Entry []Entry + Name string + Inclusions []Inclusion + Entry []Entry } -func (l *ParsedList) toPlainText() error { +func (l *List) toPlainText() error { var entryBytes []byte for _, entry := range l.Entry { var attrString string @@ -72,7 +81,7 @@ func (l *ParsedList) toPlainText() error { return nil } -func (l *ParsedList) toProto() (*router.GeoSite, error) { +func (l *List) toProto() (*router.GeoSite, error) { site := &router.GeoSite{ CountryCode: l.Name, } @@ -101,7 +110,7 @@ func (l *ParsedList) toProto() (*router.GeoSite, error) { return site, nil } -func exportPlainTextList(exportFiles []string, entryList *ParsedList) { +func exportPlainTextList(exportFiles []string, entryList *List) { for _, exportfilename := range exportFiles { if strings.EqualFold(entryList.Name, exportfilename) { if err := entryList.toPlainText(); err != nil { @@ -196,39 +205,77 @@ func Load(path string) (*List, error) { } func ParseList(refList *List) (*ParsedList, error) { - pl := &ParsedList{ - Name: refList.Name, - Inclusion: make(map[string]bool), + //TODO: one Entry -> multiple ParsedLists + pl := &ParsedList{Name: refList.Name} + for _, entry := range refList.Entry { + if entry.Type == RuleTypeInclude { + inc := Inclusion{Source: strings.ToUpper(entry.Value)} + for _, attr := range entry.Attrs { + if strings.HasPrefix(attr, "-") { + inc.BannedAttrs = append(inc.BannedAttrs, attr[1:]) // Trim attribute prefix `-` character + } else { + inc.MustAttrs = append(inc.MustAttrs, attr) + } + } + pl.Inclusions = append(pl.Inclusions, inc) + } else { + pl.Entry = append(pl.Entry, entry) + } } - entryList := refList.Entry - for { - newEntryList := make([]Entry, 0, len(entryList)) - hasInclude := false - for _, entry := range entryList { - if entry.Type == RuleTypeInclude { - refName := strings.ToUpper(entry.Value) - if pl.Inclusion[refName] { - continue + return pl, nil +} + +func isMatchAttrFilters(entry Entry, incFilter Inclusion) bool { + attrMap := make(map[string]bool) + for _, attr := range entry.Attrs { + attrMap[attr] = true + } + for _, m := range incFilter.MustAttrs { + if !attrMap[m] { return false } + } + for _, b := range incFilter.BannedAttrs { + if attrMap[b] { return false } + } + return true +} + +func ResolveList(pl *ParsedList) error { + if _, pldone := finalMap[pl.Name]; pldone { return nil } + + if cirIncMap[pl.Name] { + return fmt.Errorf("circular inclusion in: %s", pl.Name) + } + cirIncMap[pl.Name] = true + defer delete(cirIncMap, pl.Name) + + entry2String := func(e Entry) string { // Attributes already sorted + return e.Type+":"+e.Value+"@"+strings.Join(e.Attrs, "@") + } + bscDupMap := make(map[string]bool) // Used for basic duplicates detection + finalList := &List{Name: pl.Name} + + for _, dentry := range pl.Entry { + if dstring := entry2String(dentry); !bscDupMap[dstring] { + bscDupMap[dstring] = true + finalList.Entry = append(finalList.Entry, dentry) + } + } + + for _, inc := range pl.Inclusions { + if err := ResolveList(plMap[inc.Source]); err != nil { + return err + } + for _, ientry := range finalMap[inc.Source].Entry { + if isMatchAttrFilters(ientry, inc) { + if istring := entry2String(ientry); !bscDupMap[istring] { + bscDupMap[istring] = true + finalList.Entry = append(finalList.Entry, ientry) } - pl.Inclusion[refName] = true - refList := refMap[refName] - if refList == nil { - return nil, fmt.Errorf("list not found: %s", entry.Value) - } - newEntryList = append(newEntryList, refList.Entry...) - hasInclude = true - } else { - newEntryList = append(newEntryList, entry) } } - entryList = newEntryList - if !hasInclude { - break - } } - pl.Entry = entryList - - return pl, nil + finalMap[pl.Name] = finalList + return nil } func main() { @@ -237,6 +284,7 @@ func main() { dir := *dataPath fmt.Println("Use domain lists in", dir) + // Generate refMap err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { if err != nil { return err @@ -256,6 +304,24 @@ func main() { os.Exit(1) } + // Generate plMap + for refName, refList := range refMap { + pl, err := ParseList(refList) + if err != nil { + fmt.Println("Failed to ParseList:", err) + os.Exit(1) + } + plMap[refName] = pl + } + + // Generate finalMap + for _, pl := range plMap { + if err := ResolveList(pl); err != nil { + fmt.Println("Failed to ResolveList:", err) + os.Exit(1) + } + } + // Create output directory if not exist if _, err := os.Stat(*outputDir); os.IsNotExist(err) { if mkErr := os.MkdirAll(*outputDir, 0755); mkErr != nil { @@ -266,13 +332,8 @@ func main() { protoList := new(router.GeoSiteList) var existList []string - for _, refList := range refMap { - pl, err := ParseList(refList) - if err != nil { - fmt.Println("Failed:", err) - os.Exit(1) - } - site, err := pl.toProto() + for _, siteEntries := range finalMap { + site, err := siteEntries.toProto() if err != nil { fmt.Println("Failed:", err) os.Exit(1) @@ -282,7 +343,7 @@ func main() { // Flatten and export plaintext list if *exportLists != "" { if existList != nil { - exportPlainTextList(existList, pl) + exportPlainTextList(existList, siteEntries) } else { exportedListSlice := strings.Split(*exportLists, ",") for _, exportedListName := range exportedListSlice { @@ -295,7 +356,7 @@ func main() { } } if existList != nil { - exportPlainTextList(existList, pl) + exportPlainTextList(existList, siteEntries) } } } @@ -308,11 +369,11 @@ func main() { protoBytes, err := proto.Marshal(protoList) if err != nil { - fmt.Println("Failed:", err) + fmt.Println("Failed to marshal:", err) os.Exit(1) } if err := os.WriteFile(filepath.Join(*outputDir, *outputName), protoBytes, 0644); err != nil { - fmt.Println("Failed:", err) + fmt.Println("Failed to write output:", err) os.Exit(1) } else { fmt.Println(*outputName, "has been generated successfully.")