diff --git a/main.go b/main.go index 4a141418..489f4190 100644 --- a/main.go +++ b/main.go @@ -29,6 +29,12 @@ const ( RuleTypeInclude string = "include" ) +var ( + TypeChecker = regexp.MustCompile(`^(domain|full|keyword|regexp|include)$`) + ValueChecker = regexp.MustCompile(`^[a-z0-9!\.-]+$`) + AttrChecker = regexp.MustCompile(`^[a-z0-9!-]+$`) +) + type Entry struct { Type string Value string @@ -80,18 +86,11 @@ func (l *ParsedList) toProto() (*router.GeoSite, error) { case RuleTypeDomain: pdomain.Type = router.Domain_RootDomain case RuleTypeRegexp: - // check regexp validity to avoid runtime error - _, err := regexp.Compile(entry.Value) - if err != nil { - return nil, fmt.Errorf("invalid regexp in list %s: %s", l.Name, entry.Value) - } pdomain.Type = router.Domain_Regex case RuleTypeKeyword: pdomain.Type = router.Domain_Plain case RuleTypeFullDomain: pdomain.Type = router.Domain_Full - default: - return nil, fmt.Errorf("unknown domain type: %s", entry.Type) } site.Domain = append(site.Domain, pdomain) } @@ -110,58 +109,53 @@ func exportPlainTextList(list []string, refName string, pl *ParsedList) { } } -func parseDomain(domain string, entry *Entry) error { - kv := strings.Split(domain, ":") +func parseEntry(line string) (Entry, error) { + var entry Entry + parts := strings.Fields(line) + + // Parse type and value + rawTypeVal := parts[0] + kv := strings.Split(rawTypeVal, ":") if len(kv) == 1 { - entry.Type = RuleTypeDomain - entry.Value = strings.ToLower(kv[0]) - return nil - } - - if len(kv) == 2 { + entry.Type = RuleTypeDomain // Default type + entry.Value = strings.ToLower(rawTypeVal) + } else if len(kv) == 2 { entry.Type = strings.ToLower(kv[0]) - - if strings.EqualFold(entry.Type, RuleTypeRegexp) { + if entry.Type == RuleTypeRegexp { entry.Value = kv[1] } else { entry.Value = strings.ToLower(kv[1]) } - - return nil + } else { + return entry, fmt.Errorf("invalid format: %s", line) + } + // Check type and value + if !TypeChecker.MatchString(entry.Type) { + return entry, fmt.Errorf("invalid type: %s", entry.Type) + } + if entry.Type == RuleTypeRegexp { + if _, err := regexp.Compile(entry.Value); err != nil { + return entry, fmt.Errorf("invalid regexp: %s", entry.Value) + } + } else if !ValueChecker.MatchString(entry.Value) { + return entry, fmt.Errorf("invalid value: %s", entry.Value) } - return fmt.Errorf("invalid format: %s", domain) -} - -func parseAttribute(attr string) (string, error) { - var attribute string - if len(attr) == 0 || attr[0] != '@' { - return attribute, fmt.Errorf("invalid attribute: %s", attr) - } - - attribute = strings.ToLower(attr[1:]) // Trim attribute prefix `@` character - return attribute, nil -} - -func parseEntry(line string) (Entry, error) { - parts := strings.Split(line, " ") - - var entry Entry - if len(parts) == 0 { - return entry, fmt.Errorf("empty entry") - } - - if err := parseDomain(parts[0], &entry); err != nil { - return entry, err - } - - for i := 1; i < len(parts); i++ { - attr, err := parseAttribute(parts[i]) - if err != nil { - return entry, err + // Parse/Check attributes + for _, part := range parts[1:] { + if !strings.HasPrefix(part, "@") { + return entry, fmt.Errorf("invalid attribute: %s", part) + } + attr := strings.ToLower(part[1:]) // Trim attribute prefix `@` character + if !AttrChecker.MatchString(attr) { + return entry, fmt.Errorf("invalid attribute key: %s", attr) } entry.Attrs = append(entry.Attrs, attr) } + // Sort attributes + sort.Slice(entry.Attrs, func(i, j int) bool { + return entry.Attrs[i] < entry.Attrs[j] + }) return entry, nil }