diff options
Diffstat (limited to 'internal/db/bundb/domain.go')
-rw-r--r-- | internal/db/bundb/domain.go | 115 |
1 files changed, 95 insertions, 20 deletions
diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index fadb6dcf9..4cad75e4d 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -20,59 +20,134 @@ package bundb import ( "context" + "database/sql" "net/url" "strings" + "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/util" ) type domainDB struct { - conn *DBConn + conn *DBConn + cache *cache.DomainBlockCache } -func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) { - if domain == "" || domain == config.GetHost() { - return false, nil +func (d *domainDB) CreateDomainBlock(ctx context.Context, block gtsmodel.DomainBlock) db.Error { + // Normalize to lowercase + block.Domain = strings.ToLower(block.Domain) + + // Attempt to insert new domain block + _, err := d.conn.NewInsert(). + Model(&block). + Exec(ctx, &block) + if err != nil { + return d.conn.ProcessError(err) } + // Cache this domain block + d.cache.Put(block.Domain, &block) + + return nil +} + +func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, db.Error) { + // Normalize to lowercase + domain = strings.ToLower(domain) + + // Check for easy case, domain referencing *us* + if domain == "" || domain == config.GetAccountDomain() { + return nil, db.ErrNoEntries + } + + // Check for already cached rblock + if block, ok := d.cache.GetByDomain(domain); ok { + // A 'nil' return value is a sentinel value for no block + if block == nil { + return nil, db.ErrNoEntries + } + + // Else, this block exists + return block, nil + } + + block := >smodel.DomainBlock{} + q := d.conn. NewSelect(). - Model(>smodel.DomainBlock{}). - ExcludeColumn("id", "created_at", "updated_at", "created_by_account_id", "private_comment", "public_comment", "obfuscate", "subscription_id"). + Model(block). Where("domain = ?", domain). Limit(1) - return d.conn.Exists(ctx, q) + // Query database for domain block + switch err := q.Scan(ctx); err { + // No error, block found + case nil: + d.cache.Put(domain, block) + return block, nil + + // No error, simply not found + case sql.ErrNoRows: + d.cache.Put(domain, nil) + return nil, db.ErrNoEntries + + // Any other db error + default: + return nil, d.conn.ProcessError(err) + } } -func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) { - // filter out any doubles - uniqueDomains := util.UniqueStrings(domains) +func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error { + // Normalize to lowercase + domain = strings.ToLower(domain) - for _, domain := range uniqueDomains { - if blocked, err := d.IsDomainBlocked(ctx, strings.ToLower(domain)); err != nil { + // Attempt to delete domain block + _, err := d.conn.NewDelete(). + Model((*gtsmodel.DomainBlock)(nil)). + Where("domain = ?", domain). + Exec(ctx, nil) + if err != nil { + return d.conn.ProcessError(err) + } + + // Clear domain from cache + d.cache.InvalidateByDomain(domain) + + return nil +} + +func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) { + block, err := d.GetDomainBlock(ctx, domain) + if err == nil || err == db.ErrNoEntries { + return (block != nil), nil + } + return false, err +} + +func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) { + for _, domain := range domains { + if blocked, err := d.IsDomainBlocked(ctx, domain); err != nil { return false, err } else if blocked { return blocked, nil } } - - // no blocks found return false, nil } func (d *domainDB) IsURIBlocked(ctx context.Context, uri *url.URL) (bool, db.Error) { - domain := uri.Hostname() - return d.IsDomainBlocked(ctx, domain) + return d.IsDomainBlocked(ctx, uri.Hostname()) } func (d *domainDB) AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, db.Error) { - domains := []string{} for _, uri := range uris { - domains = append(domains, uri.Hostname()) + if blocked, err := d.IsDomainBlocked(ctx, uri.Hostname()); err != nil { + return false, err + } else if blocked { + return blocked, nil + } } - return d.AreDomainsBlocked(ctx, domains) + return false, nil } |