diff options
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/bundb/domain.go | 148 | ||||
-rw-r--r-- | internal/db/bundb/domain_test.go | 53 | ||||
-rw-r--r-- | internal/db/bundb/migrations/20230908083121_allowlist.go.go | 62 | ||||
-rw-r--r-- | internal/db/domain.go | 34 |
4 files changed, 291 insertions, 6 deletions
diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index c989d4fe4..dd626bc0a 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -23,6 +23,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/util" @@ -34,6 +35,102 @@ type domainDB struct { state *state.State } +func (d *domainDB) CreateDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow) error { + // Normalize the domain as punycode + var err error + allow.Domain, err = util.Punify(allow.Domain) + if err != nil { + return err + } + + // Attempt to store domain allow in DB + if _, err := d.db.NewInsert(). + Model(allow). + Exec(ctx); err != nil { + return err + } + + // Clear the domain allow cache (for later reload) + d.state.Caches.GTS.DomainAllow().Clear() + + return nil +} + +func (d *domainDB) GetDomainAllow(ctx context.Context, domain string) (*gtsmodel.DomainAllow, error) { + // Normalize the domain as punycode + domain, err := util.Punify(domain) + if err != nil { + return nil, err + } + + // Check for easy case, domain referencing *us* + if domain == "" || domain == config.GetAccountDomain() || + domain == config.GetHost() { + return nil, db.ErrNoEntries + } + + var allow gtsmodel.DomainAllow + + // Look for allow matching domain in DB + q := d.db. + NewSelect(). + Model(&allow). + Where("? = ?", bun.Ident("domain_allow.domain"), domain) + if err := q.Scan(ctx); err != nil { + return nil, err + } + + return &allow, nil +} + +func (d *domainDB) GetDomainAllows(ctx context.Context) ([]*gtsmodel.DomainAllow, error) { + allows := []*gtsmodel.DomainAllow{} + + if err := d.db. + NewSelect(). + Model(&allows). + Scan(ctx); err != nil { + return nil, err + } + + return allows, nil +} + +func (d *domainDB) GetDomainAllowByID(ctx context.Context, id string) (*gtsmodel.DomainAllow, error) { + var allow gtsmodel.DomainAllow + + q := d.db. + NewSelect(). + Model(&allow). + Where("? = ?", bun.Ident("domain_allow.id"), id) + if err := q.Scan(ctx); err != nil { + return nil, err + } + + return &allow, nil +} + +func (d *domainDB) DeleteDomainAllow(ctx context.Context, domain string) error { + // Normalize the domain as punycode + domain, err := util.Punify(domain) + if err != nil { + return err + } + + // Attempt to delete domain allow + if _, err := d.db.NewDelete(). + Model((*gtsmodel.DomainAllow)(nil)). + Where("? = ?", bun.Ident("domain_allow.domain"), domain). + Exec(ctx); err != nil { + return err + } + + // Clear the domain allow cache (for later reload) + d.state.Caches.GTS.DomainAllow().Clear() + + return nil +} + func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error { // Normalize the domain as punycode var err error @@ -137,14 +234,32 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, er return false, err } - // Check for easy case, domain referencing *us* + // Domain referencing *us* cannot be blocked. if domain == "" || domain == config.GetAccountDomain() || domain == config.GetHost() { return false, nil } + // Check the cache for an explicit domain allow (hydrating the cache with callback if necessary). + explicitAllow, err := d.state.Caches.GTS.DomainAllow().Matches(domain, func() ([]string, error) { + var domains []string + + // Scan list of all explicitly allowed domains from DB + q := d.db.NewSelect(). + Table("domain_allows"). + Column("domain") + if err := q.Scan(ctx, &domains); err != nil { + return nil, err + } + + return domains, nil + }) + if err != nil { + return false, err + } + // Check the cache for a domain block (hydrating the cache with callback if necessary) - return d.state.Caches.GTS.DomainBlock().IsBlocked(domain, func() ([]string, error) { + explicitBlock, err := d.state.Caches.GTS.DomainBlock().Matches(domain, func() ([]string, error) { var domains []string // Scan list of all blocked domains from DB @@ -157,6 +272,35 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, er return domains, nil }) + if err != nil { + return false, err + } + + // Calculate if blocked + // based on federation mode. + switch mode := config.GetInstanceFederationMode(); mode { + + case config.InstanceFederationModeBlocklist: + // Blocklist/default mode: explicit allow + // takes precedence over explicit block. + // + // Domains that have neither block + // or allow entries are allowed. + return !(explicitAllow || !explicitBlock), nil + + case config.InstanceFederationModeAllowlist: + // Allowlist mode: explicit block takes + // precedence over explicit allow. + // + // Domains that have neither block + // or allow entries are blocked. + return (explicitBlock || !explicitAllow), nil + + default: + // This should never happen but account + // for it anyway to make the code tidier. + return false, gtserror.Newf("unrecognized federation mode: %s", mode) + } } func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, error) { diff --git a/internal/db/bundb/domain_test.go b/internal/db/bundb/domain_test.go index e4e199fa1..ff687cf59 100644 --- a/internal/db/bundb/domain_test.go +++ b/internal/db/bundb/domain_test.go @@ -55,6 +55,59 @@ func (suite *DomainTestSuite) TestIsDomainBlocked() { suite.WithinDuration(time.Now(), domainBlock.CreatedAt, 10*time.Second) } +func (suite *DomainTestSuite) TestIsDomainBlockedWithAllow() { + ctx := context.Background() + + domainBlock := >smodel.DomainBlock{ + ID: "01G204214Y9TNJEBX39C7G88SW", + Domain: "some.bad.apples", + CreatedByAccountID: suite.testAccounts["admin_account"].ID, + CreatedByAccount: suite.testAccounts["admin_account"], + } + + // no domain block exists for the given domain yet + blocked, err := suite.db.IsDomainBlocked(ctx, domainBlock.Domain) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.False(blocked) + + // Block this domain. + if err := suite.db.CreateDomainBlock(ctx, domainBlock); err != nil { + suite.FailNow(err.Error()) + } + + // domain block now exists + blocked, err = suite.db.IsDomainBlocked(ctx, domainBlock.Domain) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.True(blocked) + suite.WithinDuration(time.Now(), domainBlock.CreatedAt, 10*time.Second) + + // Explicitly allow this domain. + domainAllow := >smodel.DomainAllow{ + ID: "01H8KY9MJQFWE712EG3VN02Y3J", + Domain: "some.bad.apples", + CreatedByAccountID: suite.testAccounts["admin_account"].ID, + CreatedByAccount: suite.testAccounts["admin_account"], + } + + if err := suite.db.CreateDomainAllow(ctx, domainAllow); err != nil { + suite.FailNow(err.Error()) + } + + // Domain allow now exists + blocked, err = suite.db.IsDomainBlocked(ctx, domainBlock.Domain) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.False(blocked) +} + func (suite *DomainTestSuite) TestIsDomainBlockedWildcard() { ctx := context.Background() diff --git a/internal/db/bundb/migrations/20230908083121_allowlist.go.go b/internal/db/bundb/migrations/20230908083121_allowlist.go.go new file mode 100644 index 000000000..2d86f8c03 --- /dev/null +++ b/internal/db/bundb/migrations/20230908083121_allowlist.go.go @@ -0,0 +1,62 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package migrations + +import ( + "context" + + gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +func init() { + up := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + // Create domain allow. + if _, err := tx. + NewCreateTable(). + Model(>smodel.DomainAllow{}). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + + // Index domain allow. + if _, err := tx. + NewCreateIndex(). + Table("domain_allows"). + Index("domain_allows_domain_idx"). + Column("domain"). + Exec(ctx); err != nil { + return err + } + + return nil + }) + } + + down := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + return nil + }) + } + + if err := Migrations.Register(up, down); err != nil { + panic(err) + } +} diff --git a/internal/db/domain.go b/internal/db/domain.go index 740ccefe6..3f7803d62 100644 --- a/internal/db/domain.go +++ b/internal/db/domain.go @@ -26,6 +26,25 @@ import ( // Domain contains DB functions related to domains and domain blocks. type Domain interface { + /* + Block/allow storage + retrieval functions. + */ + + // CreateDomainAllow puts the given instance-level domain allow into the database. + CreateDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow) error + + // GetDomainAllow returns one instance-level domain allow with the given domain, if it exists. + GetDomainAllow(ctx context.Context, domain string) (*gtsmodel.DomainAllow, error) + + // GetDomainAllowByID returns one instance-level domain allow with the given id, if it exists. + GetDomainAllowByID(ctx context.Context, id string) (*gtsmodel.DomainAllow, error) + + // GetDomainAllows returns all instance-level domain allows currently enforced by this instance. + GetDomainAllows(ctx context.Context) ([]*gtsmodel.DomainAllow, error) + + // DeleteDomainAllow deletes an instance-level domain allow with the given domain, if it exists. + DeleteDomainAllow(ctx context.Context, domain string) error + // CreateDomainBlock puts the given instance-level domain block into the database. CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error @@ -41,15 +60,22 @@ type Domain interface { // DeleteDomainBlock deletes an instance-level domain block with the given domain, if it exists. DeleteDomainBlock(ctx context.Context, domain string) error - // IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`). + /* + Block/allow checking functions. + */ + + // IsDomainBlocked checks if domain is blocked, accounting for both explicit allows and blocks. + // Will check allows first, so an allowed domain will always return false, even if it's also blocked. IsDomainBlocked(ctx context.Context, domain string) (bool, error) - // AreDomainsBlocked checks if an instance-level domain block exists for any of the given domains strings, and returns true if even one is found. + // AreDomainsBlocked calls IsDomainBlocked for each domain. + // Will return true if even one of the given domains is blocked. AreDomainsBlocked(ctx context.Context, domains []string) (bool, error) - // IsURIBlocked checks if an instance-level domain block exists for the `host` in the given URI (eg., `https://example.org/users/whatever`). + // IsURIBlocked calls IsDomainBlocked for the host of the given URI. IsURIBlocked(ctx context.Context, uri *url.URL) (bool, error) - // AreURIsBlocked checks if an instance-level domain block exists for any `host` in the given URI slice, and returns true if even one is found. + // AreURIsBlocked calls IsURIBlocked for each URI. + // Will return true if even one of the given URIs is blocked. AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, error) } |