summaryrefslogtreecommitdiff
path: root/internal/db/bundb/domain.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/domain.go')
-rw-r--r--internal/db/bundb/domain.go115
1 files changed, 95 insertions, 20 deletions
diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go
index fadb6dcf9..4cad75e4d 100644
--- a/internal/db/bundb/domain.go
+++ b/internal/db/bundb/domain.go
@@ -20,59 +20,134 @@ package bundb
import (
"context"
+ "database/sql"
"net/url"
"strings"
+ "github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/util"
)
type domainDB struct {
- conn *DBConn
+ conn *DBConn
+ cache *cache.DomainBlockCache
}
-func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) {
- if domain == "" || domain == config.GetHost() {
- return false, nil
+func (d *domainDB) CreateDomainBlock(ctx context.Context, block gtsmodel.DomainBlock) db.Error {
+ // Normalize to lowercase
+ block.Domain = strings.ToLower(block.Domain)
+
+ // Attempt to insert new domain block
+ _, err := d.conn.NewInsert().
+ Model(&block).
+ Exec(ctx, &block)
+ if err != nil {
+ return d.conn.ProcessError(err)
}
+ // Cache this domain block
+ d.cache.Put(block.Domain, &block)
+
+ return nil
+}
+
+func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, db.Error) {
+ // Normalize to lowercase
+ domain = strings.ToLower(domain)
+
+ // Check for easy case, domain referencing *us*
+ if domain == "" || domain == config.GetAccountDomain() {
+ return nil, db.ErrNoEntries
+ }
+
+ // Check for already cached rblock
+ if block, ok := d.cache.GetByDomain(domain); ok {
+ // A 'nil' return value is a sentinel value for no block
+ if block == nil {
+ return nil, db.ErrNoEntries
+ }
+
+ // Else, this block exists
+ return block, nil
+ }
+
+ block := &gtsmodel.DomainBlock{}
+
q := d.conn.
NewSelect().
- Model(&gtsmodel.DomainBlock{}).
- ExcludeColumn("id", "created_at", "updated_at", "created_by_account_id", "private_comment", "public_comment", "obfuscate", "subscription_id").
+ Model(block).
Where("domain = ?", domain).
Limit(1)
- return d.conn.Exists(ctx, q)
+ // Query database for domain block
+ switch err := q.Scan(ctx); err {
+ // No error, block found
+ case nil:
+ d.cache.Put(domain, block)
+ return block, nil
+
+ // No error, simply not found
+ case sql.ErrNoRows:
+ d.cache.Put(domain, nil)
+ return nil, db.ErrNoEntries
+
+ // Any other db error
+ default:
+ return nil, d.conn.ProcessError(err)
+ }
}
-func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) {
- // filter out any doubles
- uniqueDomains := util.UniqueStrings(domains)
+func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error {
+ // Normalize to lowercase
+ domain = strings.ToLower(domain)
- for _, domain := range uniqueDomains {
- if blocked, err := d.IsDomainBlocked(ctx, strings.ToLower(domain)); err != nil {
+ // Attempt to delete domain block
+ _, err := d.conn.NewDelete().
+ Model((*gtsmodel.DomainBlock)(nil)).
+ Where("domain = ?", domain).
+ Exec(ctx, nil)
+ if err != nil {
+ return d.conn.ProcessError(err)
+ }
+
+ // Clear domain from cache
+ d.cache.InvalidateByDomain(domain)
+
+ return nil
+}
+
+func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) {
+ block, err := d.GetDomainBlock(ctx, domain)
+ if err == nil || err == db.ErrNoEntries {
+ return (block != nil), nil
+ }
+ return false, err
+}
+
+func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) {
+ for _, domain := range domains {
+ if blocked, err := d.IsDomainBlocked(ctx, domain); err != nil {
return false, err
} else if blocked {
return blocked, nil
}
}
-
- // no blocks found
return false, nil
}
func (d *domainDB) IsURIBlocked(ctx context.Context, uri *url.URL) (bool, db.Error) {
- domain := uri.Hostname()
- return d.IsDomainBlocked(ctx, domain)
+ return d.IsDomainBlocked(ctx, uri.Hostname())
}
func (d *domainDB) AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, db.Error) {
- domains := []string{}
for _, uri := range uris {
- domains = append(domains, uri.Hostname())
+ if blocked, err := d.IsDomainBlocked(ctx, uri.Hostname()); err != nil {
+ return false, err
+ } else if blocked {
+ return blocked, nil
+ }
}
- return d.AreDomainsBlocked(ctx, domains)
+ return false, nil
}