summaryrefslogtreecommitdiff
path: root/internal/db/bundb/domain.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/domain.go')
-rw-r--r--internal/db/bundb/domain.go148
1 files changed, 146 insertions, 2 deletions
diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go
index c989d4fe4..dd626bc0a 100644
--- a/internal/db/bundb/domain.go
+++ b/internal/db/bundb/domain.go
@@ -23,6 +23,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
@@ -34,6 +35,102 @@ type domainDB struct {
state *state.State
}
+func (d *domainDB) CreateDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow) error {
+ // Normalize the domain as punycode
+ var err error
+ allow.Domain, err = util.Punify(allow.Domain)
+ if err != nil {
+ return err
+ }
+
+ // Attempt to store domain allow in DB
+ if _, err := d.db.NewInsert().
+ Model(allow).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // Clear the domain allow cache (for later reload)
+ d.state.Caches.GTS.DomainAllow().Clear()
+
+ return nil
+}
+
+func (d *domainDB) GetDomainAllow(ctx context.Context, domain string) (*gtsmodel.DomainAllow, error) {
+ // Normalize the domain as punycode
+ domain, err := util.Punify(domain)
+ if err != nil {
+ return nil, err
+ }
+
+ // Check for easy case, domain referencing *us*
+ if domain == "" || domain == config.GetAccountDomain() ||
+ domain == config.GetHost() {
+ return nil, db.ErrNoEntries
+ }
+
+ var allow gtsmodel.DomainAllow
+
+ // Look for allow matching domain in DB
+ q := d.db.
+ NewSelect().
+ Model(&allow).
+ Where("? = ?", bun.Ident("domain_allow.domain"), domain)
+ if err := q.Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return &allow, nil
+}
+
+func (d *domainDB) GetDomainAllows(ctx context.Context) ([]*gtsmodel.DomainAllow, error) {
+ allows := []*gtsmodel.DomainAllow{}
+
+ if err := d.db.
+ NewSelect().
+ Model(&allows).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return allows, nil
+}
+
+func (d *domainDB) GetDomainAllowByID(ctx context.Context, id string) (*gtsmodel.DomainAllow, error) {
+ var allow gtsmodel.DomainAllow
+
+ q := d.db.
+ NewSelect().
+ Model(&allow).
+ Where("? = ?", bun.Ident("domain_allow.id"), id)
+ if err := q.Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return &allow, nil
+}
+
+func (d *domainDB) DeleteDomainAllow(ctx context.Context, domain string) error {
+ // Normalize the domain as punycode
+ domain, err := util.Punify(domain)
+ if err != nil {
+ return err
+ }
+
+ // Attempt to delete domain allow
+ if _, err := d.db.NewDelete().
+ Model((*gtsmodel.DomainAllow)(nil)).
+ Where("? = ?", bun.Ident("domain_allow.domain"), domain).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // Clear the domain allow cache (for later reload)
+ d.state.Caches.GTS.DomainAllow().Clear()
+
+ return nil
+}
+
func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error {
// Normalize the domain as punycode
var err error
@@ -137,14 +234,32 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, er
return false, err
}
- // Check for easy case, domain referencing *us*
+ // Domain referencing *us* cannot be blocked.
if domain == "" || domain == config.GetAccountDomain() ||
domain == config.GetHost() {
return false, nil
}
+ // Check the cache for an explicit domain allow (hydrating the cache with callback if necessary).
+ explicitAllow, err := d.state.Caches.GTS.DomainAllow().Matches(domain, func() ([]string, error) {
+ var domains []string
+
+ // Scan list of all explicitly allowed domains from DB
+ q := d.db.NewSelect().
+ Table("domain_allows").
+ Column("domain")
+ if err := q.Scan(ctx, &domains); err != nil {
+ return nil, err
+ }
+
+ return domains, nil
+ })
+ if err != nil {
+ return false, err
+ }
+
// Check the cache for a domain block (hydrating the cache with callback if necessary)
- return d.state.Caches.GTS.DomainBlock().IsBlocked(domain, func() ([]string, error) {
+ explicitBlock, err := d.state.Caches.GTS.DomainBlock().Matches(domain, func() ([]string, error) {
var domains []string
// Scan list of all blocked domains from DB
@@ -157,6 +272,35 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, er
return domains, nil
})
+ if err != nil {
+ return false, err
+ }
+
+ // Calculate if blocked
+ // based on federation mode.
+ switch mode := config.GetInstanceFederationMode(); mode {
+
+ case config.InstanceFederationModeBlocklist:
+ // Blocklist/default mode: explicit allow
+ // takes precedence over explicit block.
+ //
+ // Domains that have neither block
+ // or allow entries are allowed.
+ return !(explicitAllow || !explicitBlock), nil
+
+ case config.InstanceFederationModeAllowlist:
+ // Allowlist mode: explicit block takes
+ // precedence over explicit allow.
+ //
+ // Domains that have neither block
+ // or allow entries are blocked.
+ return (explicitBlock || !explicitAllow), nil
+
+ default:
+ // This should never happen but account
+ // for it anyway to make the code tidier.
+ return false, gtserror.Newf("unrecognized federation mode: %s", mode)
+ }
}
func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, error) {