diff options
Diffstat (limited to 'internal/db/bundb')
| -rw-r--r-- | internal/db/bundb/domain.go | 148 | ||||
| -rw-r--r-- | internal/db/bundb/domain_test.go | 53 | ||||
| -rw-r--r-- | internal/db/bundb/migrations/20230908083121_allowlist.go.go | 62 | 
3 files changed, 261 insertions, 2 deletions
| diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index c989d4fe4..dd626bc0a 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -23,6 +23,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/util" @@ -34,6 +35,102 @@ type domainDB struct {  	state *state.State  } +func (d *domainDB) CreateDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow) error { +	// Normalize the domain as punycode +	var err error +	allow.Domain, err = util.Punify(allow.Domain) +	if err != nil { +		return err +	} + +	// Attempt to store domain allow in DB +	if _, err := d.db.NewInsert(). +		Model(allow). +		Exec(ctx); err != nil { +		return err +	} + +	// Clear the domain allow cache (for later reload) +	d.state.Caches.GTS.DomainAllow().Clear() + +	return nil +} + +func (d *domainDB) GetDomainAllow(ctx context.Context, domain string) (*gtsmodel.DomainAllow, error) { +	// Normalize the domain as punycode +	domain, err := util.Punify(domain) +	if err != nil { +		return nil, err +	} + +	// Check for easy case, domain referencing *us* +	if domain == "" || domain == config.GetAccountDomain() || +		domain == config.GetHost() { +		return nil, db.ErrNoEntries +	} + +	var allow gtsmodel.DomainAllow + +	// Look for allow matching domain in DB +	q := d.db. +		NewSelect(). +		Model(&allow). +		Where("? = ?", bun.Ident("domain_allow.domain"), domain) +	if err := q.Scan(ctx); err != nil { +		return nil, err +	} + +	return &allow, nil +} + +func (d *domainDB) GetDomainAllows(ctx context.Context) ([]*gtsmodel.DomainAllow, error) { +	allows := []*gtsmodel.DomainAllow{} + +	if err := d.db. +		NewSelect(). +		Model(&allows). +		Scan(ctx); err != nil { +		return nil, err +	} + +	return allows, nil +} + +func (d *domainDB) GetDomainAllowByID(ctx context.Context, id string) (*gtsmodel.DomainAllow, error) { +	var allow gtsmodel.DomainAllow + +	q := d.db. +		NewSelect(). +		Model(&allow). +		Where("? = ?", bun.Ident("domain_allow.id"), id) +	if err := q.Scan(ctx); err != nil { +		return nil, err +	} + +	return &allow, nil +} + +func (d *domainDB) DeleteDomainAllow(ctx context.Context, domain string) error { +	// Normalize the domain as punycode +	domain, err := util.Punify(domain) +	if err != nil { +		return err +	} + +	// Attempt to delete domain allow +	if _, err := d.db.NewDelete(). +		Model((*gtsmodel.DomainAllow)(nil)). +		Where("? = ?", bun.Ident("domain_allow.domain"), domain). +		Exec(ctx); err != nil { +		return err +	} + +	// Clear the domain allow cache (for later reload) +	d.state.Caches.GTS.DomainAllow().Clear() + +	return nil +} +  func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error {  	// Normalize the domain as punycode  	var err error @@ -137,14 +234,32 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, er  		return false, err  	} -	// Check for easy case, domain referencing *us* +	// Domain referencing *us* cannot be blocked.  	if domain == "" || domain == config.GetAccountDomain() ||  		domain == config.GetHost() {  		return false, nil  	} +	// Check the cache for an explicit domain allow (hydrating the cache with callback if necessary). +	explicitAllow, err := d.state.Caches.GTS.DomainAllow().Matches(domain, func() ([]string, error) { +		var domains []string + +		// Scan list of all explicitly allowed domains from DB +		q := d.db.NewSelect(). +			Table("domain_allows"). +			Column("domain") +		if err := q.Scan(ctx, &domains); err != nil { +			return nil, err +		} + +		return domains, nil +	}) +	if err != nil { +		return false, err +	} +  	// Check the cache for a domain block (hydrating the cache with callback if necessary) -	return d.state.Caches.GTS.DomainBlock().IsBlocked(domain, func() ([]string, error) { +	explicitBlock, err := d.state.Caches.GTS.DomainBlock().Matches(domain, func() ([]string, error) {  		var domains []string  		// Scan list of all blocked domains from DB @@ -157,6 +272,35 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, er  		return domains, nil  	}) +	if err != nil { +		return false, err +	} + +	// Calculate if blocked +	// based on federation mode. +	switch mode := config.GetInstanceFederationMode(); mode { + +	case config.InstanceFederationModeBlocklist: +		// Blocklist/default mode: explicit allow +		// takes precedence over explicit block. +		// +		// Domains that have neither block +		// or allow entries are allowed. +		return !(explicitAllow || !explicitBlock), nil + +	case config.InstanceFederationModeAllowlist: +		// Allowlist mode: explicit block takes +		// precedence over explicit allow. +		// +		// Domains that have neither block +		// or allow entries are blocked. +		return (explicitBlock || !explicitAllow), nil + +	default: +		// This should never happen but account +		// for it anyway to make the code tidier. +		return false, gtserror.Newf("unrecognized federation mode: %s", mode) +	}  }  func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, error) { diff --git a/internal/db/bundb/domain_test.go b/internal/db/bundb/domain_test.go index e4e199fa1..ff687cf59 100644 --- a/internal/db/bundb/domain_test.go +++ b/internal/db/bundb/domain_test.go @@ -55,6 +55,59 @@ func (suite *DomainTestSuite) TestIsDomainBlocked() {  	suite.WithinDuration(time.Now(), domainBlock.CreatedAt, 10*time.Second)  } +func (suite *DomainTestSuite) TestIsDomainBlockedWithAllow() { +	ctx := context.Background() + +	domainBlock := >smodel.DomainBlock{ +		ID:                 "01G204214Y9TNJEBX39C7G88SW", +		Domain:             "some.bad.apples", +		CreatedByAccountID: suite.testAccounts["admin_account"].ID, +		CreatedByAccount:   suite.testAccounts["admin_account"], +	} + +	// no domain block exists for the given domain yet +	blocked, err := suite.db.IsDomainBlocked(ctx, domainBlock.Domain) +	if err != nil { +		suite.FailNow(err.Error()) +	} + +	suite.False(blocked) + +	// Block this domain. +	if err := suite.db.CreateDomainBlock(ctx, domainBlock); err != nil { +		suite.FailNow(err.Error()) +	} + +	// domain block now exists +	blocked, err = suite.db.IsDomainBlocked(ctx, domainBlock.Domain) +	if err != nil { +		suite.FailNow(err.Error()) +	} + +	suite.True(blocked) +	suite.WithinDuration(time.Now(), domainBlock.CreatedAt, 10*time.Second) + +	// Explicitly allow this domain. +	domainAllow := >smodel.DomainAllow{ +		ID:                 "01H8KY9MJQFWE712EG3VN02Y3J", +		Domain:             "some.bad.apples", +		CreatedByAccountID: suite.testAccounts["admin_account"].ID, +		CreatedByAccount:   suite.testAccounts["admin_account"], +	} + +	if err := suite.db.CreateDomainAllow(ctx, domainAllow); err != nil { +		suite.FailNow(err.Error()) +	} + +	// Domain allow now exists +	blocked, err = suite.db.IsDomainBlocked(ctx, domainBlock.Domain) +	if err != nil { +		suite.FailNow(err.Error()) +	} + +	suite.False(blocked) +} +  func (suite *DomainTestSuite) TestIsDomainBlockedWildcard() {  	ctx := context.Background() diff --git a/internal/db/bundb/migrations/20230908083121_allowlist.go.go b/internal/db/bundb/migrations/20230908083121_allowlist.go.go new file mode 100644 index 000000000..2d86f8c03 --- /dev/null +++ b/internal/db/bundb/migrations/20230908083121_allowlist.go.go @@ -0,0 +1,62 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package migrations + +import ( +	"context" + +	gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/uptrace/bun" +) + +func init() { +	up := func(ctx context.Context, db *bun.DB) error { +		return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +			// Create domain allow. +			if _, err := tx. +				NewCreateTable(). +				Model(>smodel.DomainAllow{}). +				IfNotExists(). +				Exec(ctx); err != nil { +				return err +			} + +			// Index domain allow. +			if _, err := tx. +				NewCreateIndex(). +				Table("domain_allows"). +				Index("domain_allows_domain_idx"). +				Column("domain"). +				Exec(ctx); err != nil { +				return err +			} + +			return nil +		}) +	} + +	down := func(ctx context.Context, db *bun.DB) error { +		return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +			return nil +		}) +	} + +	if err := Migrations.Register(up, down); err != nil { +		panic(err) +	} +} | 
