diff options
| -rw-r--r-- | internal/db/bundb/domain.go | 43 | ||||
| -rw-r--r-- | internal/db/bundb/domain_test.go | 72 | 
2 files changed, 103 insertions, 12 deletions
diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index 0d67837d7..5d262c676 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -28,6 +28,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"golang.org/x/net/idna"  )  type domainDB struct { @@ -35,15 +36,28 @@ type domainDB struct {  	cache *cache.DomainBlockCache  } +// normalizeDomain converts the given domain to lowercase +// then to punycode (for international domain names). +// +// Returns the resulting domain or an error if the +// punycode conversion fails. +func normalizeDomain(domain string) (out string, err error) { +	out = strings.ToLower(domain) +	out, err = idna.ToASCII(out) +	return out, err +} +  func (d *domainDB) CreateDomainBlock(ctx context.Context, block gtsmodel.DomainBlock) db.Error { -	// Normalize to lowercase -	block.Domain = strings.ToLower(block.Domain) +	domain, err := normalizeDomain(block.Domain) +	if err != nil { +		return err +	} +	block.Domain = domain  	// Attempt to insert new domain block -	_, err := d.conn.NewInsert(). +	if _, err := d.conn.NewInsert().  		Model(&block). -		Exec(ctx, &block) -	if err != nil { +		Exec(ctx, &block); err != nil {  		return d.conn.ProcessError(err)  	} @@ -54,8 +68,11 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block gtsmodel.DomainB  }  func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, db.Error) { -	// Normalize to lowercase -	domain = strings.ToLower(domain) +	var err error +	domain, err = normalizeDomain(domain) +	if err != nil { +		return nil, err +	}  	// Check for easy case, domain referencing *us*  	if domain == "" || domain == config.GetAccountDomain() { @@ -100,15 +117,17 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel  }  func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error { -	// Normalize to lowercase -	domain = strings.ToLower(domain) +	var err error +	domain, err = normalizeDomain(domain) +	if err != nil { +		return err +	}  	// Attempt to delete domain block -	_, err := d.conn.NewDelete(). +	if _, err := d.conn.NewDelete().  		Model((*gtsmodel.DomainBlock)(nil)).  		Where("domain = ?", domain). -		Exec(ctx) -	if err != nil { +		Exec(ctx); err != nil {  		return d.conn.ProcessError(err)  	} diff --git a/internal/db/bundb/domain_test.go b/internal/db/bundb/domain_test.go index b326236ad..48c4a7798 100644 --- a/internal/db/bundb/domain_test.go +++ b/internal/db/bundb/domain_test.go @@ -59,6 +59,78 @@ func (suite *DomainTestSuite) TestIsDomainBlocked() {  	suite.True(blocked)  } +func (suite *DomainTestSuite) TestIsDomainBlockedNonASCII() { +	ctx := context.Background() + +	now := time.Now() + +	domainBlock := >smodel.DomainBlock{ +		ID:                 "01G204214Y9TNJEBX39C7G88SW", +		Domain:             "xn--80aaa1bbb1h.com", +		CreatedAt:          now, +		UpdatedAt:          now, +		CreatedByAccountID: suite.testAccounts["admin_account"].ID, +		CreatedByAccount:   suite.testAccounts["admin_account"], +	} + +	// no domain block exists for the given domain yet +	blocked, err := suite.db.IsDomainBlocked(ctx, "какашка.com") +	suite.NoError(err) +	suite.False(blocked) + +	blocked, err = suite.db.IsDomainBlocked(ctx, "xn--80aaa1bbb1h.com") +	suite.NoError(err) +	suite.False(blocked) + +	err = suite.db.CreateDomainBlock(ctx, *domainBlock) +	suite.NoError(err) + +	// domain block now exists +	blocked, err = suite.db.IsDomainBlocked(ctx, "какашка.com") +	suite.NoError(err) +	suite.True(blocked) + +	blocked, err = suite.db.IsDomainBlocked(ctx, "xn--80aaa1bbb1h.com") +	suite.NoError(err) +	suite.True(blocked) +} + +func (suite *DomainTestSuite) TestIsDomainBlockedNonASCII2() { +	ctx := context.Background() + +	now := time.Now() + +	domainBlock := >smodel.DomainBlock{ +		ID:                 "01G204214Y9TNJEBX39C7G88SW", +		Domain:             "какашка.com", +		CreatedAt:          now, +		UpdatedAt:          now, +		CreatedByAccountID: suite.testAccounts["admin_account"].ID, +		CreatedByAccount:   suite.testAccounts["admin_account"], +	} + +	// no domain block exists for the given domain yet +	blocked, err := suite.db.IsDomainBlocked(ctx, "какашка.com") +	suite.NoError(err) +	suite.False(blocked) + +	blocked, err = suite.db.IsDomainBlocked(ctx, "xn--80aaa1bbb1h.com") +	suite.NoError(err) +	suite.False(blocked) + +	err = suite.db.CreateDomainBlock(ctx, *domainBlock) +	suite.NoError(err) + +	// domain block now exists +	blocked, err = suite.db.IsDomainBlocked(ctx, "какашка.com") +	suite.NoError(err) +	suite.True(blocked) + +	blocked, err = suite.db.IsDomainBlocked(ctx, "xn--80aaa1bbb1h.com") +	suite.NoError(err) +	suite.True(blocked) +} +  func TestDomainTestSuite(t *testing.T) {  	suite.Run(t, new(DomainTestSuite))  }  | 
