summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/cache/domain/domain.go47
-rw-r--r--internal/cache/domain/domain_test.go24
-rw-r--r--internal/db/bundb/domain_test.go62
3 files changed, 114 insertions, 19 deletions
diff --git a/internal/cache/domain/domain.go b/internal/cache/domain/domain.go
index 1b836ed28..274a244f7 100644
--- a/internal/cache/domain/domain.go
+++ b/internal/cache/domain/domain.go
@@ -19,10 +19,9 @@ package domain
import (
"fmt"
+ "slices"
"strings"
"sync/atomic"
-
- "golang.org/x/exp/slices"
)
// Cache provides a means of caching domains in memory to reduce
@@ -57,6 +56,24 @@ func (c *Cache) Matches(domain string, load func() ([]string, error)) (bool, err
return false, fmt.Errorf("error reloading cache: %w", err)
}
+ // Ensure the domains being inserted into the cache
+ // are sorted by number of domain parts. i.e. those
+ // with less parts are inserted last, else this can
+ // allow domains to fall through the matching code!
+ slices.SortFunc(domains, func(a, b string) int {
+ const k = +1
+ an := strings.Count(a, ".")
+ bn := strings.Count(b, ".")
+ switch {
+ case an < bn:
+ return +k
+ case an > bn:
+ return -k
+ default:
+ return 0
+ }
+ })
+
// Allocate new radix trie
// node to store matches.
ptr = new(root)
@@ -94,13 +111,13 @@ type root struct{ root node }
// Add will add the given domain to the radix trie.
func (r *root) Add(domain string) {
- r.root.add(strings.Split(domain, "."))
+ r.root.Add(strings.Split(domain, "."))
}
// Match will return whether the given domain matches
// an existing stored domain in this radix trie.
func (r *root) Match(domain string) bool {
- return r.root.match(strings.Split(domain, "."))
+ return r.root.Match(strings.Split(domain, "."))
}
// Sort will sort the entire radix trie ensuring that
@@ -114,7 +131,7 @@ func (r *root) Sort() {
// String returns a string representation of node (and its descendants).
func (r *root) String() string {
buf := new(strings.Builder)
- r.root.writestr(buf, "")
+ r.root.WriteStr(buf, "")
return buf.String()
}
@@ -123,7 +140,7 @@ type node struct {
child []*node
}
-func (n *node) add(parts []string) {
+func (n *node) Add(parts []string) {
if len(parts) == 0 {
panic("invalid domain")
}
@@ -165,7 +182,7 @@ func (n *node) add(parts []string) {
}
}
-func (n *node) match(parts []string) bool {
+func (n *node) Match(parts []string) bool {
for len(parts) > 0 {
// Pop next domain part.
i := len(parts) - 1
@@ -226,8 +243,16 @@ func (n *node) getChild(part string) *node {
func (n *node) sort() {
// Sort this node's slice of child nodes.
- slices.SortFunc(n.child, func(i, j *node) bool {
- return i.part < j.part
+ slices.SortFunc(n.child, func(i, j *node) int {
+ const k = -1
+ switch {
+ case i.part < j.part:
+ return +k
+ case i.part > j.part:
+ return -k
+ default:
+ return 0
+ }
})
// Sort each child node's children.
@@ -236,7 +261,7 @@ func (n *node) sort() {
}
}
-func (n *node) writestr(buf *strings.Builder, prefix string) {
+func (n *node) WriteStr(buf *strings.Builder, prefix string) {
if prefix != "" {
// Suffix joining '.'
prefix += "."
@@ -251,6 +276,6 @@ func (n *node) writestr(buf *strings.Builder, prefix string) {
// Iterate through node children.
for _, child := range n.child {
- child.writestr(buf, prefix)
+ child.WriteStr(buf, prefix)
}
}
diff --git a/internal/cache/domain/domain_test.go b/internal/cache/domain/domain_test.go
index 9e091e1d0..974425b7c 100644
--- a/internal/cache/domain/domain_test.go
+++ b/internal/cache/domain/domain_test.go
@@ -28,9 +28,13 @@ func TestCache(t *testing.T) {
c := new(domain.Cache)
cachedDomains := []string{
- "google.com",
- "google.co.uk",
- "pleroma.bad.host",
+ "google.com", //
+ "mail.google.com", // should be ignored since covered above
+ "dev.mail.google.com", // same again
+ "google.co.uk", //
+ "mail.google.co.uk", //
+ "pleroma.bad.host", //
+ "pleroma.still.a.bad.host", //
}
loader := func() ([]string, error) {
@@ -38,22 +42,25 @@ func TestCache(t *testing.T) {
return cachedDomains, nil
}
- // Check a list of known cached domains.
+ // Check a list of known matching domains.
for _, domain := range []string{
"google.com",
"mail.google.com",
+ "dev.mail.google.com",
"google.co.uk",
"mail.google.co.uk",
"pleroma.bad.host",
"dev.pleroma.bad.host",
+ "pleroma.still.a.bad.host",
+ "dev.pleroma.still.a.bad.host",
} {
t.Logf("checking domain matches: %s", domain)
if b, _ := c.Matches(domain, loader); !b {
- t.Errorf("domain should be matched: %s", domain)
+ t.Fatalf("domain should be matched: %s", domain)
}
}
- // Check a list of known uncached domains.
+ // Check a list of known unmatched domains.
for _, domain := range []string{
"askjeeves.com",
"ask-kim.co.uk",
@@ -61,10 +68,11 @@ func TestCache(t *testing.T) {
"mail.google.ie",
"gts.bad.host",
"mastodon.bad.host",
+ "akkoma.still.a.bad.host",
} {
t.Logf("checking domain isn't matched: %s", domain)
if b, _ := c.Matches(domain, loader); b {
- t.Errorf("domain should not be matched: %s", domain)
+ t.Fatalf("domain should not be matched: %s", domain)
}
}
@@ -80,6 +88,6 @@ func TestCache(t *testing.T) {
t.Log("load: returning known error")
return nil, knownErr
}); !errors.Is(err, knownErr) {
- t.Errorf("matches did not return expected error: %v", err)
+ t.Fatalf("matches did not return expected error: %v", err)
}
}
diff --git a/internal/db/bundb/domain_test.go b/internal/db/bundb/domain_test.go
index ff687cf59..8164259e8 100644
--- a/internal/db/bundb/domain_test.go
+++ b/internal/db/bundb/domain_test.go
@@ -19,6 +19,7 @@ package bundb_test
import (
"context"
+ "slices"
"testing"
"time"
@@ -212,6 +213,67 @@ func (suite *DomainTestSuite) TestIsDomainBlockedNonASCII2() {
suite.True(blocked)
}
+func (suite *DomainTestSuite) TestIsOtherDomainBlockedWildcardAndExplicit() {
+ ctx := context.Background()
+
+ blocks := []*gtsmodel.DomainBlock{
+ {
+ ID: "01G204214Y9TNJEBX39C7G88SW",
+ Domain: "bad.apples",
+ CreatedByAccountID: suite.testAccounts["admin_account"].ID,
+ CreatedByAccount: suite.testAccounts["admin_account"],
+ },
+ {
+ ID: "01HKPSVQ864FQ2JJ01CDGPHHMJ",
+ Domain: "some.bad.apples",
+ CreatedByAccountID: suite.testAccounts["admin_account"].ID,
+ CreatedByAccount: suite.testAccounts["admin_account"],
+ },
+ }
+
+ for _, block := range blocks {
+ if err := suite.db.CreateDomainBlock(ctx, block); err != nil {
+ suite.FailNow(err.Error())
+ }
+ }
+
+ // Ensure each block created
+ // above is now present in the db.
+ dbBlocks, err := suite.db.GetDomainBlocks(ctx)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ for _, block := range blocks {
+ if !slices.ContainsFunc(
+ dbBlocks,
+ func(dbBlock *gtsmodel.DomainBlock) bool {
+ return block.Domain == dbBlock.Domain
+ },
+ ) {
+ suite.FailNow("", "stored blocks did not contain %s", block.Domain)
+ }
+ }
+
+ // All domains and subdomains
+ // should now be blocked, even
+ // ones without an explicit block.
+ for _, domain := range []string{
+ "bad.apples",
+ "some.bad.apples",
+ "other.bad.apples",
+ } {
+ blocked, err := suite.db.IsDomainBlocked(ctx, domain)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ if !blocked {
+ suite.Fail("", "domain %s should be blocked", domain)
+ }
+ }
+}
+
func TestDomainTestSuite(t *testing.T) {
suite.Run(t, new(DomainTestSuite))
}