diff options
| -rw-r--r-- | internal/cache/domain/domain.go | 170 | ||||
| -rw-r--r-- | internal/cache/domain/domain_test.go | 85 | ||||
| -rw-r--r-- | internal/cache/gts.go | 21 | ||||
| -rw-r--r-- | internal/db/bundb/domain.go | 81 | ||||
| -rw-r--r-- | internal/db/bundb/domain_test.go | 32 | 
5 files changed, 350 insertions, 39 deletions
diff --git a/internal/cache/domain/domain.go b/internal/cache/domain/domain.go new file mode 100644 index 000000000..4697f05a6 --- /dev/null +++ b/internal/cache/domain/domain.go @@ -0,0 +1,170 @@ +/* +   GoToSocial +   Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + +   This program is free software: you can redistribute it and/or modify +   it under the terms of the GNU Affero General Public License as published by +   the Free Software Foundation, either version 3 of the License, or +   (at your option) any later version. + +   This program is distributed in the hope that it will be useful, +   but WITHOUT ANY WARRANTY; without even the implied warranty of +   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +   GNU Affero General Public License for more details. + +   You should have received a copy of the GNU Affero General Public License +   along with this program.  If not, see <http://www.gnu.org/licenses/>. +*/ + +package domain + +import ( +	"fmt" +	"time" + +	"codeberg.org/gruf/go-cache/v3/ttl" +	"github.com/miekg/dns" +) + +// BlockCache provides a means of caching domain blocks in memory to reduce load +// on an underlying storage mechanism, e.g. a database. +// +// It consists of a TTL primary cache that stores calculated domain string to block results, +// that on cache miss is filled by calculating block status by iterating over a list of all of +// the domain blocks stored in memory. This reduces CPU usage required by not need needing to +// iterate through a possible 100-1000s long block list, while saving memory by having a primary +// cache of limited size that evicts stale entries. The raw list of all domain blocks should in +// most cases be negligible when it comes to memory usage. +// +// The in-memory block list is kept up-to-date by means of a passed loader function during every +// call to .IsBlocked(). In the case of a nil internal block list, the loader function is called to +// hydrate the cache with the latest list of domain blocks. The .Clear() function can be used to invalidate +// the cache, e.g. when a domain block is added / deleted from the database. It will drop the current +// list of domain blocks and clear all entries from the primary cache. +type BlockCache struct { +	pcache *ttl.Cache[string, bool] // primary cache of domains -> block results +	blocks []block                  // raw list of all domain blocks, nil => not loaded. +} + +// New returns a new initialized BlockCache instance with given primary cache capacity and TTL. +func New(pcap int, pttl time.Duration) *BlockCache { +	c := new(BlockCache) +	c.pcache = new(ttl.Cache[string, bool]) +	c.pcache.Init(0, pcap, pttl) +	return c +} + +// Start will start the cache background eviction routine with given sweep frequency. If already running or a freq <= 0 provided, this is a no-op. This will block until the eviction routine has started. +func (b *BlockCache) Start(pfreq time.Duration) bool { +	return b.pcache.Start(pfreq) +} + +// Stop will stop cache background eviction routine. If not running this is a no-op. This will block until the eviction routine has stopped. +func (b *BlockCache) Stop() bool { +	return b.pcache.Stop() +} + +// IsBlocked checks whether domain is blocked. If the cache is not currently loaded, then the provided load function is used to hydrate it. +// NOTE: be VERY careful using any kind of locking mechanism within the load function, as this itself is ran within the cache mutex lock. +func (b *BlockCache) IsBlocked(domain string, load func() ([]string, error)) (bool, error) { +	var blocked bool + +	// Acquire cache lock +	b.pcache.Lock() +	defer b.pcache.Unlock() + +	// Check primary cache for result +	entry, ok := b.pcache.Cache.Get(domain) +	if ok { +		return entry.Value, nil +	} + +	if b.blocks == nil { +		// Cache is not hydrated +		// +		// Load domains from callback +		domains, err := load() +		if err != nil { +			return false, fmt.Errorf("error reloading cache: %w", err) +		} + +		// Drop all domain blocks and recreate +		b.blocks = make([]block, len(domains)) + +		for i, domain := range domains { +			// Store pre-split labels for each domain block +			b.blocks[i].labels = dns.SplitDomainName(domain) +		} +	} + +	// Split domain into it separate labels +	labels := dns.SplitDomainName(domain) + +	// Compare this to our stored blocks +	for _, block := range b.blocks { +		if block.Blocks(labels) { +			blocked = true +			break +		} +	} + +	// Store block result in primary cache +	b.pcache.Cache.Set(domain, &ttl.Entry[string, bool]{ +		Key:    domain, +		Value:  blocked, +		Expiry: time.Now().Add(b.pcache.TTL), +	}) + +	return blocked, nil +} + +// Clear will drop the currently loaded domain list, and clear the primary cache. +// This will trigger a reload on next call to .IsBlocked(). +func (b *BlockCache) Clear() { +	// Drop all blocks. +	b.pcache.Lock() +	b.blocks = nil +	b.pcache.Unlock() + +	// Clear needs to be done _outside_ of +	// lock, as also acquires a mutex lock. +	b.pcache.Clear() +} + +// block represents a domain block, and stores the +// deconstructed labels of a singular domain block. +// e.g. []string{"gts", "superseriousbusiness", "org"}. +type block struct { +	labels []string +} + +// Blocks checks whether the separated domain labels of an +// incoming domain matches the stored (receiving struct) block. +func (b block) Blocks(labels []string) bool { +	// Calculate length difference +	d := len(labels) - len(b.labels) +	if d < 0 { +		return false +	} + +	// Iterate backwards through domain block's +	// labels, omparing against the incoming domain's. +	// +	// So for the following input: +	// labels   = []string{"mail", "google", "com"} +	// b.labels = []string{"google", "com"} +	// +	// These would be matched in reverse order along +	// the entirety of the block object's labels: +	// "com"    => match +	// "google" => match +	// +	// And so would reach the end and return true. +	for i := len(b.labels) - 1; i >= 0; i-- { +		if b.labels[i] != labels[i+d] { +			return false +		} +	} + +	return true +} diff --git a/internal/cache/domain/domain_test.go b/internal/cache/domain/domain_test.go new file mode 100644 index 000000000..416ce5012 --- /dev/null +++ b/internal/cache/domain/domain_test.go @@ -0,0 +1,85 @@ +/* +   GoToSocial +   Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + +   This program is free software: you can redistribute it and/or modify +   it under the terms of the GNU Affero General Public License as published by +   the Free Software Foundation, either version 3 of the License, or +   (at your option) any later version. + +   This program is distributed in the hope that it will be useful, +   but WITHOUT ANY WARRANTY; without even the implied warranty of +   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +   GNU Affero General Public License for more details. + +   You should have received a copy of the GNU Affero General Public License +   along with this program.  If not, see <http://www.gnu.org/licenses/>. +*/ + +package domain_test + +import ( +	"errors" +	"testing" +	"time" + +	"github.com/superseriousbusiness/gotosocial/internal/cache/domain" +) + +func TestBlockCache(t *testing.T) { +	c := domain.New(100, time.Second) + +	blocks := []string{ +		"google.com", +		"google.co.uk", +		"pleroma.bad.host", +	} + +	loader := func() ([]string, error) { +		t.Log("load: returning blocked domains") +		return blocks, nil +	} + +	// Check a list of known blocked domains. +	for _, domain := range []string{ +		"google.com", +		"mail.google.com", +		"google.co.uk", +		"mail.google.co.uk", +		"pleroma.bad.host", +		"dev.pleroma.bad.host", +	} { +		t.Logf("checking domain is blocked: %s", domain) +		if b, _ := c.IsBlocked(domain, loader); !b { +			t.Errorf("domain should be blocked: %s", domain) +		} +	} + +	// Check a list of known unblocked domains. +	for _, domain := range []string{ +		"askjeeves.com", +		"ask-kim.co.uk", +		"google.ie", +		"mail.google.ie", +		"gts.bad.host", +		"mastodon.bad.host", +	} { +		t.Logf("checking domain isn't blocked: %s", domain) +		if b, _ := c.IsBlocked(domain, loader); b { +			t.Errorf("domain should not be blocked: %s", domain) +		} +	} + +	// Clear the cache +	c.Clear() + +	knownErr := errors.New("known error") + +	// Check that reload is actually performed and returns our error +	if _, err := c.IsBlocked("", func() ([]string, error) { +		t.Log("load: returning known error") +		return nil, knownErr +	}); !errors.Is(err, knownErr) { +		t.Errorf("is blocked did not return expected error: %v", err) +	} +} diff --git a/internal/cache/gts.go b/internal/cache/gts.go index 6083b8693..3fa25ddef 100644 --- a/internal/cache/gts.go +++ b/internal/cache/gts.go @@ -20,6 +20,7 @@ package cache  import (  	"codeberg.org/gruf/go-cache/v3/result" +	"github.com/superseriousbusiness/gotosocial/internal/cache/domain"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  ) @@ -41,8 +42,8 @@ type GTSCaches interface {  	// Block provides access to the gtsmodel Block (account) database cache.  	Block() *result.Cache[*gtsmodel.Block] -	// DomainBlock provides access to the gtsmodel DomainBlock database cache. -	DomainBlock() *result.Cache[*gtsmodel.DomainBlock] +	// DomainBlock provides access to the domain block database cache. +	DomainBlock() *domain.BlockCache  	// Emoji provides access to the gtsmodel Emoji database cache.  	Emoji() *result.Cache[*gtsmodel.Emoji] @@ -74,7 +75,7 @@ func NewGTS() GTSCaches {  type gtsCaches struct {  	account       *result.Cache[*gtsmodel.Account]  	block         *result.Cache[*gtsmodel.Block] -	domainBlock   *result.Cache[*gtsmodel.DomainBlock] +	domainBlock   *domain.BlockCache  	emoji         *result.Cache[*gtsmodel.Emoji]  	emojiCategory *result.Cache[*gtsmodel.EmojiCategory]  	mention       *result.Cache[*gtsmodel.Mention] @@ -151,7 +152,7 @@ func (c *gtsCaches) Block() *result.Cache[*gtsmodel.Block] {  	return c.block  } -func (c *gtsCaches) DomainBlock() *result.Cache[*gtsmodel.DomainBlock] { +func (c *gtsCaches) DomainBlock() *domain.BlockCache {  	return c.domainBlock  } @@ -212,14 +213,10 @@ func (c *gtsCaches) initBlock() {  }  func (c *gtsCaches) initDomainBlock() { -	c.domainBlock = result.NewSized([]result.Lookup{ -		{Name: "Domain"}, -	}, func(d1 *gtsmodel.DomainBlock) *gtsmodel.DomainBlock { -		d2 := new(gtsmodel.DomainBlock) -		*d2 = *d1 -		return d2 -	}, config.GetCacheGTSDomainBlockMaxSize()) -	c.domainBlock.SetTTL(config.GetCacheGTSDomainBlockTTL(), true) +	c.domainBlock = domain.New( +		config.GetCacheGTSDomainBlockMaxSize(), +		config.GetCacheGTSDomainBlockTTL(), +	)  }  func (c *gtsCaches) initEmoji() { diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index a5d9f61e2..5407f9656 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -50,46 +50,52 @@ func normalizeDomain(domain string) (out string, err error) {  func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) db.Error {  	var err error +	// Normalize the domain as punycode  	block.Domain, err = normalizeDomain(block.Domain)  	if err != nil {  		return err  	} -	return d.state.Caches.GTS.DomainBlock().Store(block, func() error { -		_, err := d.conn.NewInsert(). -			Model(block). -			Exec(ctx) +	// Attempt to store domain in DB +	if _, err := d.conn.NewInsert(). +		Model(block). +		Exec(ctx); err != nil {  		return d.conn.ProcessError(err) -	}) +	} + +	// Clear the domain block cache (for later reload) +	d.state.Caches.GTS.DomainBlock().Clear() + +	return nil  }  func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, db.Error) {  	var err error +	// Normalize the domain as punycode  	domain, err = normalizeDomain(domain)  	if err != nil {  		return nil, err  	} -	return d.state.Caches.GTS.DomainBlock().Load("Domain", func() (*gtsmodel.DomainBlock, error) { -		// Check for easy case, domain referencing *us* -		if domain == "" || domain == config.GetAccountDomain() { -			return nil, db.ErrNoEntries -		} +	// Check for easy case, domain referencing *us* +	if domain == "" || domain == config.GetAccountDomain() || +		domain == config.GetHost() { +		return nil, db.ErrNoEntries +	} -		var block gtsmodel.DomainBlock +	var block gtsmodel.DomainBlock -		q := d.conn. -			NewSelect(). -			Model(&block). -			Where("? = ?", bun.Ident("domain_block.domain"), domain). -			Limit(1) -		if err := q.Scan(ctx); err != nil { -			return nil, d.conn.ProcessError(err) -		} +	// Look for block matching domain in DB +	q := d.conn. +		NewSelect(). +		Model(&block). +		Where("? = ?", bun.Ident("domain_block.domain"), domain) +	if err := q.Scan(ctx); err != nil { +		return nil, d.conn.ProcessError(err) +	} -		return &block, nil -	}, domain) +	return &block, nil  }  func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error { @@ -108,18 +114,39 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro  		return d.conn.ProcessError(err)  	} -	// Clear domain from cache -	d.state.Caches.GTS.DomainBlock().Invalidate(domain) +	// Clear the domain block cache (for later reload) +	d.state.Caches.GTS.DomainBlock().Clear()  	return nil  }  func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) { -	block, err := d.GetDomainBlock(ctx, domain) -	if err == nil || err == db.ErrNoEntries { -		return (block != nil), nil +	// Normalize the domain as punycode +	domain, err := normalizeDomain(domain) +	if err != nil { +		return false, err  	} -	return false, err + +	// Check for easy case, domain referencing *us* +	if domain == "" || domain == config.GetAccountDomain() || +		domain == config.GetHost() { +		return false, nil +	} + +	// Check the cache for a domain block (hydrating the cache with callback if necessary) +	return d.state.Caches.GTS.DomainBlock().IsBlocked(domain, func() ([]string, error) { +		var domains []string + +		// Scan list of all blocked domains from DB +		q := d.conn.NewSelect(). +			Table("domain_blocks"). +			Column("domain") +		if err := q.Scan(ctx, &domains); err != nil { +			return nil, d.conn.ProcessError(err) +		} + +		return domains, nil +	})  }  func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) { diff --git a/internal/db/bundb/domain_test.go b/internal/db/bundb/domain_test.go index 41a73ff80..8091e6585 100644 --- a/internal/db/bundb/domain_test.go +++ b/internal/db/bundb/domain_test.go @@ -56,6 +56,38 @@ func (suite *DomainTestSuite) TestIsDomainBlocked() {  	suite.WithinDuration(time.Now(), domainBlock.CreatedAt, 10*time.Second)  } +func (suite *DomainTestSuite) TestIsDomainBlockedWildcard() { +	ctx := context.Background() + +	domainBlock := >smodel.DomainBlock{ +		ID:                 "01G204214Y9TNJEBX39C7G88SW", +		Domain:             "bad.apples", +		CreatedByAccountID: suite.testAccounts["admin_account"].ID, +		CreatedByAccount:   suite.testAccounts["admin_account"], +	} + +	// no domain block exists for the given domain yet +	blocked, err := suite.db.IsDomainBlocked(ctx, domainBlock.Domain) +	suite.NoError(err) +	suite.False(blocked) + +	err = suite.db.CreateDomainBlock(ctx, domainBlock) +	suite.NoError(err) + +	// Start with the base block domain +	domain := domainBlock.Domain + +	for _, part := range []string{"extra", "domain", "parts"} { +		// Prepend the next domain part +		domain = part + "." + domain + +		// Check that domain block is wildcarded for this subdomain +		blocked, err = suite.db.IsDomainBlocked(ctx, domainBlock.Domain) +		suite.NoError(err) +		suite.True(blocked) +	} +} +  func (suite *DomainTestSuite) TestIsDomainBlockedNonASCII() {  	ctx := context.Background()  | 
