summary refs log tree commit diff
diff options
context:
space:
mode:
author kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com> 2024-01-09 13:12:43 +0000
committer tobi <tobi.smethurst@protonmail.com> 2024-01-10 13:56:17 +0100
commit ccecf5a7e4000f40522a790fcdaf2bf72d43d552 (patch)
tree f3480d86d96bc1e54930f49961ce7321264e8745
parent d5c305dc6e3275589d2931bcb0ae7d912c9ab04a (diff)
[bugfix] fix higher-level explicit domain rules causing issues with lower-level domain blocking (#2513)
* fix the sort direction of domain cache child nodes ...

* add more domain cache test cases

* add specific test for this bug to database domain test suite (thanks for writing this @tsmethurst!)

* remove unused field (this was a previous attempt at a fix)

* remove debugging println statements :innocent:
-rw-r--r-- internal/cache/domain/domain.go 47
-rw-r--r-- internal/cache/domain/domain_test.go 24
-rw-r--r-- internal/db/bundb/domain_test.go 62
3 files changed, 114 insertions, 19 deletions
diff --git a/internal/cache/domain/domain.go b/internal/cache/domain/domain.go
index 051ec5c1..e612f326 100644
--- a/internal/cache/domain/domain.go
+++ b/internal/cache/domain/domain.go
@@ -19,11 +19,10 @@ package domain
 
 import (
 	"fmt"
+	"slices"
 	"strings"
 	"sync/atomic"
 	"unsafe"
-
-	"golang.org/x/exp/slices"
 )
 
 // Cache provides a means of caching domains in memory to reduce
@@ -58,6 +57,24 @@ func (c *Cache) Matches(domain string, load func() ([]string, error)) (bool, err
 			return false, fmt.Errorf("error reloading cache: %w", err)
 		}
 
+		// Ensure the domains being inserted into the cache
+		// are sorted by number of domain parts. i.e. those
+		// with less parts are inserted last, else this can
+		// allow domains to fall through the matching code!
+		slices.SortFunc(domains, func(a, b string) int {
+			const k = +1
+			an := strings.Count(a, ".")
+			bn := strings.Count(b, ".")
+			switch {
+			case an < bn:
+				return +k
+			case an > bn:
+				return -k
+			default:
+				return 0
+			}
+		})
+
 		// Allocate new radix trie
 		// node to store matches.
 		root := new(root)
@@ -98,13 +115,13 @@ type root struct{ root node }
 
 // Add will add the given domain to the radix trie.
 func (r *root) Add(domain string) {
-	r.root.add(strings.Split(domain, "."))
+	r.root.Add(strings.Split(domain, "."))
 }
 
 // Match will return whether the given domain matches
 // an existing stored domain in this radix trie.
 func (r *root) Match(domain string) bool {
-	return r.root.match(strings.Split(domain, "."))
+	return r.root.Match(strings.Split(domain, "."))
 }
 
 // Sort will sort the entire radix trie ensuring that
@@ -118,7 +135,7 @@ func (r *root) Sort() {
 // String returns a string representation of node (and its descendants).
 func (r *root) String() string {
 	buf := new(strings.Builder)
-	r.root.writestr(buf, "")
+	r.root.WriteStr(buf, "")
 	return buf.String()
 }
 
@@ -127,7 +144,7 @@ type node struct {
 	child []*node
 }
 
-func (n *node) add(parts []string) {
+func (n *node) Add(parts []string) {
 	if len(parts) == 0 {
 		panic("invalid domain")
 	}
@@ -169,7 +186,7 @@ func (n *node) add(parts []string) {
 	}
 }
 
-func (n *node) match(parts []string) bool {
+func (n *node) Match(parts []string) bool {
 	for len(parts) > 0 {
 		// Pop next domain part.
 		i := len(parts) - 1
@@ -230,8 +247,16 @@ func (n *node) getChild(part string) *node {
 
 func (n *node) sort() {
 	// Sort this node's slice of child nodes.
-	slices.SortFunc(n.child, func(i, j *node) bool {
-		return i.part < j.part
+	slices.SortFunc(n.child, func(i, j *node) int {
+		const k = -1
+		switch {
+		case i.part < j.part:
+			return +k
+		case i.part > j.part:
+			return -k
+		default:
+			return 0
+		}
 	})
 
 	// Sort each child node's children.
@@ -240,7 +265,7 @@ func (n *node) sort() {
 	}
 }
 
-func (n *node) writestr(buf *strings.Builder, prefix string) {
+func (n *node) WriteStr(buf *strings.Builder, prefix string) {
 	if prefix != "" {
 		// Suffix joining '.'
 		prefix += "."
@@ -255,6 +280,6 @@ func (n *node) writestr(buf *strings.Builder, prefix string) {
 
 	// Iterate through node children.
 	for _, child := range n.child {
-		child.writestr(buf, prefix)
+		child.WriteStr(buf, prefix)
 	}
 }
diff --git a/internal/cache/domain/domain_test.go b/internal/cache/domain/domain_test.go
index 9e091e1d..974425b7 100644
--- a/internal/cache/domain/domain_test.go
+++ b/internal/cache/domain/domain_test.go
@@ -28,9 +28,13 @@ func TestCache(t *testing.T) {
 	c := new(domain.Cache)
 
 	cachedDomains := []string{
-		"google.com",
-		"google.co.uk",
-		"pleroma.bad.host",
+		"google.com",               //
+		"mail.google.com",          // should be ignored since covered above
+		"dev.mail.google.com",      // same again
+		"google.co.uk",             //
+		"mail.google.co.uk",        //
+		"pleroma.bad.host",         //
+		"pleroma.still.a.bad.host", //
 	}
 
 	loader := func() ([]string, error) {
@@ -38,22 +42,25 @@ func TestCache(t *testing.T) {
 		return cachedDomains, nil
 	}
 
-	// Check a list of known cached domains.
+	// Check a list of known matching domains.
 	for _, domain := range []string{
 		"google.com",
 		"mail.google.com",
+		"dev.mail.google.com",
 		"google.co.uk",
 		"mail.google.co.uk",
 		"pleroma.bad.host",
 		"dev.pleroma.bad.host",
+		"pleroma.still.a.bad.host",
+		"dev.pleroma.still.a.bad.host",
 	} {
 		t.Logf("checking domain matches: %s", domain)
 		if b, _ := c.Matches(domain, loader); !b {
-			t.Errorf("domain should be matched: %s", domain)
+			t.Fatalf("domain should be matched: %s", domain)
 		}
 	}
 
-	// Check a list of known uncached domains.
+	// Check a list of known unmatched domains.
 	for _, domain := range []string{
 		"askjeeves.com",
 		"ask-kim.co.uk",
@@ -61,10 +68,11 @@ func TestCache(t *testing.T) {
 		"mail.google.ie",
 		"gts.bad.host",
 		"mastodon.bad.host",
+		"akkoma.still.a.bad.host",
 	} {
 		t.Logf("checking domain isn't matched: %s", domain)
 		if b, _ := c.Matches(domain, loader); b {
-			t.Errorf("domain should not be matched: %s", domain)
+			t.Fatalf("domain should not be matched: %s", domain)
 		}
 	}
 
@@ -80,6 +88,6 @@ func TestCache(t *testing.T) {
 		t.Log("load: returning known error")
 		return nil, knownErr
 	}); !errors.Is(err, knownErr) {
-		t.Errorf("matches did not return expected error: %v", err)
+		t.Fatalf("matches did not return expected error: %v", err)
 	}
 }
diff --git a/internal/db/bundb/domain_test.go b/internal/db/bundb/domain_test.go
index ff687cf5..8164259e 100644
--- a/internal/db/bundb/domain_test.go
+++ b/internal/db/bundb/domain_test.go
@@ -19,6 +19,7 @@ package bundb_test
 
 import (
 	"context"
+	"slices"
 	"testing"
 	"time"
 
@@ -212,6 +213,67 @@ func (suite *DomainTestSuite) TestIsDomainBlockedNonASCII2() {
 	suite.True(blocked)
 }
 
+func (suite *DomainTestSuite) TestIsOtherDomainBlockedWildcardAndExplicit() {
+	ctx := context.Background()
+
+	blocks := []*gtsmodel.DomainBlock{
+		{
+			ID:                 "01G204214Y9TNJEBX39C7G88SW",
+			Domain:             "bad.apples",
+			CreatedByAccountID: suite.testAccounts["admin_account"].ID,
+			CreatedByAccount:   suite.testAccounts["admin_account"],
+		},
+		{
+			ID:                 "01HKPSVQ864FQ2JJ01CDGPHHMJ",
+			Domain:             "some.bad.apples",
+			CreatedByAccountID: suite.testAccounts["admin_account"].ID,
+			CreatedByAccount:   suite.testAccounts["admin_account"],
+		},
+	}
+
+	for _, block := range blocks {
+		if err := suite.db.CreateDomainBlock(ctx, block); err != nil {
+			suite.FailNow(err.Error())
+		}
+	}
+
+	// Ensure each block created
+	// above is now present in the db.
+	dbBlocks, err := suite.db.GetDomainBlocks(ctx)
+	if err != nil {
+		suite.FailNow(err.Error())
+	}
+
+	for _, block := range blocks {
+		if !slices.ContainsFunc(
+			dbBlocks,
+			func(dbBlock *gtsmodel.DomainBlock) bool {
+				return block.Domain == dbBlock.Domain
+			},
+		) {
+			suite.FailNow("", "stored blocks did not contain %s", block.Domain)
+		}
+	}
+
+	// All domains and subdomains
+	// should now be blocked, even
+	// ones without an explicit block.
+	for _, domain := range []string{
+		"bad.apples",
+		"some.bad.apples",
+		"other.bad.apples",
+	} {
+		blocked, err := suite.db.IsDomainBlocked(ctx, domain)
+		if err != nil {
+			suite.FailNow(err.Error())
+		}
+
+		if !blocked {
+			suite.Fail("", "domain %s should be blocked", domain)
+		}
+	}
+}
+
 func TestDomainTestSuite(t *testing.T) {
 	suite.Run(t, new(DomainTestSuite))
 }