diff options
Diffstat (limited to 'internal/cache/domain/domain.go')
-rw-r--r-- | internal/cache/domain/domain.go | 47 |
1 files changed, 36 insertions, 11 deletions
diff --git a/internal/cache/domain/domain.go b/internal/cache/domain/domain.go index 1b836ed28..274a244f7 100644 --- a/internal/cache/domain/domain.go +++ b/internal/cache/domain/domain.go @@ -19,10 +19,9 @@ package domain import ( "fmt" + "slices" "strings" "sync/atomic" - - "golang.org/x/exp/slices" ) // Cache provides a means of caching domains in memory to reduce @@ -57,6 +56,24 @@ func (c *Cache) Matches(domain string, load func() ([]string, error)) (bool, err return false, fmt.Errorf("error reloading cache: %w", err) } + // Ensure the domains being inserted into the cache + // are sorted by number of domain parts. i.e. those + // with less parts are inserted last, else this can + // allow domains to fall through the matching code! + slices.SortFunc(domains, func(a, b string) int { + const k = +1 + an := strings.Count(a, ".") + bn := strings.Count(b, ".") + switch { + case an < bn: + return +k + case an > bn: + return -k + default: + return 0 + } + }) + // Allocate new radix trie // node to store matches. ptr = new(root) @@ -94,13 +111,13 @@ type root struct{ root node } // Add will add the given domain to the radix trie. func (r *root) Add(domain string) { - r.root.add(strings.Split(domain, ".")) + r.root.Add(strings.Split(domain, ".")) } // Match will return whether the given domain matches // an existing stored domain in this radix trie. func (r *root) Match(domain string) bool { - return r.root.match(strings.Split(domain, ".")) + return r.root.Match(strings.Split(domain, ".")) } // Sort will sort the entire radix trie ensuring that @@ -114,7 +131,7 @@ func (r *root) Sort() { // String returns a string representation of node (and its descendants). func (r *root) String() string { buf := new(strings.Builder) - r.root.writestr(buf, "") + r.root.WriteStr(buf, "") return buf.String() } @@ -123,7 +140,7 @@ type node struct { child []*node } -func (n *node) add(parts []string) { +func (n *node) Add(parts []string) { if len(parts) == 0 { panic("invalid domain") } @@ -165,7 +182,7 @@ func (n *node) add(parts []string) { } } -func (n *node) match(parts []string) bool { +func (n *node) Match(parts []string) bool { for len(parts) > 0 { // Pop next domain part. i := len(parts) - 1 @@ -226,8 +243,16 @@ func (n *node) getChild(part string) *node { func (n *node) sort() { // Sort this node's slice of child nodes. - slices.SortFunc(n.child, func(i, j *node) bool { - return i.part < j.part + slices.SortFunc(n.child, func(i, j *node) int { + const k = -1 + switch { + case i.part < j.part: + return +k + case i.part > j.part: + return -k + default: + return 0 + } }) // Sort each child node's children. @@ -236,7 +261,7 @@ func (n *node) sort() { } } -func (n *node) writestr(buf *strings.Builder, prefix string) { +func (n *node) WriteStr(buf *strings.Builder, prefix string) { if prefix != "" { // Suffix joining '.' prefix += "." @@ -251,6 +276,6 @@ func (n *node) writestr(buf *strings.Builder, prefix string) { // Iterate through node children. for _, child := range n.child { - child.writestr(buf, prefix) + child.WriteStr(buf, prefix) } } |