diff options
Diffstat (limited to 'internal/db/bundb')
| -rw-r--r-- | internal/db/bundb/domain.go | 81 | ||||
| -rw-r--r-- | internal/db/bundb/domain_test.go | 32 | 
2 files changed, 86 insertions, 27 deletions
diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index a5d9f61e2..5407f9656 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -50,46 +50,52 @@ func normalizeDomain(domain string) (out string, err error) {  func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) db.Error {  	var err error +	// Normalize the domain as punycode  	block.Domain, err = normalizeDomain(block.Domain)  	if err != nil {  		return err  	} -	return d.state.Caches.GTS.DomainBlock().Store(block, func() error { -		_, err := d.conn.NewInsert(). -			Model(block). -			Exec(ctx) +	// Attempt to store domain in DB +	if _, err := d.conn.NewInsert(). +		Model(block). +		Exec(ctx); err != nil {  		return d.conn.ProcessError(err) -	}) +	} + +	// Clear the domain block cache (for later reload) +	d.state.Caches.GTS.DomainBlock().Clear() + +	return nil  }  func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, db.Error) {  	var err error +	// Normalize the domain as punycode  	domain, err = normalizeDomain(domain)  	if err != nil {  		return nil, err  	} -	return d.state.Caches.GTS.DomainBlock().Load("Domain", func() (*gtsmodel.DomainBlock, error) { -		// Check for easy case, domain referencing *us* -		if domain == "" || domain == config.GetAccountDomain() { -			return nil, db.ErrNoEntries -		} +	// Check for easy case, domain referencing *us* +	if domain == "" || domain == config.GetAccountDomain() || +		domain == config.GetHost() { +		return nil, db.ErrNoEntries +	} -		var block gtsmodel.DomainBlock +	var block gtsmodel.DomainBlock -		q := d.conn. -			NewSelect(). -			Model(&block). -			Where("? = ?", bun.Ident("domain_block.domain"), domain). -			Limit(1) -		if err := q.Scan(ctx); err != nil { -			return nil, d.conn.ProcessError(err) -		} +	// Look for block matching domain in DB +	q := d.conn. +		NewSelect(). +		Model(&block). +		Where("? = ?", bun.Ident("domain_block.domain"), domain) +	if err := q.Scan(ctx); err != nil { +		return nil, d.conn.ProcessError(err) +	} -		return &block, nil -	}, domain) +	return &block, nil  }  func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error { @@ -108,18 +114,39 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro  		return d.conn.ProcessError(err)  	} -	// Clear domain from cache -	d.state.Caches.GTS.DomainBlock().Invalidate(domain) +	// Clear the domain block cache (for later reload) +	d.state.Caches.GTS.DomainBlock().Clear()  	return nil  }  func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) { -	block, err := d.GetDomainBlock(ctx, domain) -	if err == nil || err == db.ErrNoEntries { -		return (block != nil), nil +	// Normalize the domain as punycode +	domain, err := normalizeDomain(domain) +	if err != nil { +		return false, err  	} -	return false, err + +	// Check for easy case, domain referencing *us* +	if domain == "" || domain == config.GetAccountDomain() || +		domain == config.GetHost() { +		return false, nil +	} + +	// Check the cache for a domain block (hydrating the cache with callback if necessary) +	return d.state.Caches.GTS.DomainBlock().IsBlocked(domain, func() ([]string, error) { +		var domains []string + +		// Scan list of all blocked domains from DB +		q := d.conn.NewSelect(). +			Table("domain_blocks"). +			Column("domain") +		if err := q.Scan(ctx, &domains); err != nil { +			return nil, d.conn.ProcessError(err) +		} + +		return domains, nil +	})  }  func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) { diff --git a/internal/db/bundb/domain_test.go b/internal/db/bundb/domain_test.go index 41a73ff80..8091e6585 100644 --- a/internal/db/bundb/domain_test.go +++ b/internal/db/bundb/domain_test.go @@ -56,6 +56,38 @@ func (suite *DomainTestSuite) TestIsDomainBlocked() {  	suite.WithinDuration(time.Now(), domainBlock.CreatedAt, 10*time.Second)  } +func (suite *DomainTestSuite) TestIsDomainBlockedWildcard() { +	ctx := context.Background() + +	domainBlock := >smodel.DomainBlock{ +		ID:                 "01G204214Y9TNJEBX39C7G88SW", +		Domain:             "bad.apples", +		CreatedByAccountID: suite.testAccounts["admin_account"].ID, +		CreatedByAccount:   suite.testAccounts["admin_account"], +	} + +	// no domain block exists for the given domain yet +	blocked, err := suite.db.IsDomainBlocked(ctx, domainBlock.Domain) +	suite.NoError(err) +	suite.False(blocked) + +	err = suite.db.CreateDomainBlock(ctx, domainBlock) +	suite.NoError(err) + +	// Start with the base block domain +	domain := domainBlock.Domain + +	for _, part := range []string{"extra", "domain", "parts"} { +		// Prepend the next domain part +		domain = part + "." + domain + +		// Check that domain block is wildcarded for this subdomain +		blocked, err = suite.db.IsDomainBlocked(ctx, domainBlock.Domain) +		suite.NoError(err) +		suite.True(blocked) +	} +} +  func (suite *DomainTestSuite) TestIsDomainBlockedNonASCII() {  	ctx := context.Background()  | 
