diff options
Diffstat (limited to 'internal/db/bundb/domain.go')
-rw-r--r-- | internal/db/bundb/domain.go | 148 |
1 files changed, 146 insertions, 2 deletions
diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index c989d4fe4..dd626bc0a 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -23,6 +23,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/util" @@ -34,6 +35,102 @@ type domainDB struct { state *state.State } +func (d *domainDB) CreateDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow) error { + // Normalize the domain as punycode + var err error + allow.Domain, err = util.Punify(allow.Domain) + if err != nil { + return err + } + + // Attempt to store domain allow in DB + if _, err := d.db.NewInsert(). + Model(allow). + Exec(ctx); err != nil { + return err + } + + // Clear the domain allow cache (for later reload) + d.state.Caches.GTS.DomainAllow().Clear() + + return nil +} + +func (d *domainDB) GetDomainAllow(ctx context.Context, domain string) (*gtsmodel.DomainAllow, error) { + // Normalize the domain as punycode + domain, err := util.Punify(domain) + if err != nil { + return nil, err + } + + // Check for easy case, domain referencing *us* + if domain == "" || domain == config.GetAccountDomain() || + domain == config.GetHost() { + return nil, db.ErrNoEntries + } + + var allow gtsmodel.DomainAllow + + // Look for allow matching domain in DB + q := d.db. + NewSelect(). + Model(&allow). + Where("? = ?", bun.Ident("domain_allow.domain"), domain) + if err := q.Scan(ctx); err != nil { + return nil, err + } + + return &allow, nil +} + +func (d *domainDB) GetDomainAllows(ctx context.Context) ([]*gtsmodel.DomainAllow, error) { + allows := []*gtsmodel.DomainAllow{} + + if err := d.db. + NewSelect(). + Model(&allows). + Scan(ctx); err != nil { + return nil, err + } + + return allows, nil +} + +func (d *domainDB) GetDomainAllowByID(ctx context.Context, id string) (*gtsmodel.DomainAllow, error) { + var allow gtsmodel.DomainAllow + + q := d.db. + NewSelect(). + Model(&allow). + Where("? = ?", bun.Ident("domain_allow.id"), id) + if err := q.Scan(ctx); err != nil { + return nil, err + } + + return &allow, nil +} + +func (d *domainDB) DeleteDomainAllow(ctx context.Context, domain string) error { + // Normalize the domain as punycode + domain, err := util.Punify(domain) + if err != nil { + return err + } + + // Attempt to delete domain allow + if _, err := d.db.NewDelete(). + Model((*gtsmodel.DomainAllow)(nil)). + Where("? = ?", bun.Ident("domain_allow.domain"), domain). + Exec(ctx); err != nil { + return err + } + + // Clear the domain allow cache (for later reload) + d.state.Caches.GTS.DomainAllow().Clear() + + return nil +} + func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error { // Normalize the domain as punycode var err error @@ -137,14 +234,32 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, er return false, err } - // Check for easy case, domain referencing *us* + // Domain referencing *us* cannot be blocked. if domain == "" || domain == config.GetAccountDomain() || domain == config.GetHost() { return false, nil } + // Check the cache for an explicit domain allow (hydrating the cache with callback if necessary). + explicitAllow, err := d.state.Caches.GTS.DomainAllow().Matches(domain, func() ([]string, error) { + var domains []string + + // Scan list of all explicitly allowed domains from DB + q := d.db.NewSelect(). + Table("domain_allows"). + Column("domain") + if err := q.Scan(ctx, &domains); err != nil { + return nil, err + } + + return domains, nil + }) + if err != nil { + return false, err + } + // Check the cache for a domain block (hydrating the cache with callback if necessary) - return d.state.Caches.GTS.DomainBlock().IsBlocked(domain, func() ([]string, error) { + explicitBlock, err := d.state.Caches.GTS.DomainBlock().Matches(domain, func() ([]string, error) { var domains []string // Scan list of all blocked domains from DB @@ -157,6 +272,35 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, er return domains, nil }) + if err != nil { + return false, err + } + + // Calculate if blocked + // based on federation mode. + switch mode := config.GetInstanceFederationMode(); mode { + + case config.InstanceFederationModeBlocklist: + // Blocklist/default mode: explicit allow + // takes precedence over explicit block. + // + // Domains that have neither block + // or allow entries are allowed. + return !(explicitAllow || !explicitBlock), nil + + case config.InstanceFederationModeAllowlist: + // Allowlist mode: explicit block takes + // precedence over explicit allow. + // + // Domains that have neither block + // or allow entries are blocked. + return (explicitBlock || !explicitAllow), nil + + default: + // This should never happen but account + // for it anyway to make the code tidier. + return false, gtserror.Newf("unrecognized federation mode: %s", mode) + } } func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, error) { |