summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
authorLibravatar tobi <31960611+tsmethurst@users.noreply.github.com>2023-07-07 11:34:12 +0200
committerLibravatar GitHub <noreply@github.com>2023-07-07 11:34:12 +0200
commite70bf8a6c82e3d5c943550b364fc6f8120f6f07e (patch)
treef408ccff2e6f2451bf95ee9a5d96e5b678d686d5 /internal/db
parent[chore/performance] Remove remaining 'whereEmptyOrNull' funcs (#1946) (diff)
downloadgotosocial-e70bf8a6c82e3d5c943550b364fc6f8120f6f07e.tar.xz
[chore/bugfix] Domain block tidying up, Implement first pass of `207 Multi-Status` (#1886)
* [chore/refactor] update domain block processing * expose domain block import errors a lil better * move/remove unused query keys
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/bundb/bundb.go3
-rw-r--r--internal/db/bundb/domain.go29
-rw-r--r--internal/db/bundb/instance.go218
-rw-r--r--internal/db/domain.go12
-rw-r--r--internal/db/instance.go9
5 files changed, 248 insertions, 23 deletions
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index 9d616954a..ee28800b5 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -179,7 +179,8 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
state: state,
},
Instance: &instanceDB{
- conn: conn,
+ conn: conn,
+ state: state,
},
List: &listDB{
conn: conn,
diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go
index 5c92645de..2e8ce2a6b 100644
--- a/internal/db/bundb/domain.go
+++ b/internal/db/bundb/domain.go
@@ -42,7 +42,7 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.Domain
return err
}
- // Attempt to store domain in DB
+ // Attempt to store domain block in DB
if _, err := d.conn.NewInsert().
Model(block).
Exec(ctx); err != nil {
@@ -82,6 +82,33 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel
return &block, nil
}
+func (d *domainDB) GetDomainBlocks(ctx context.Context) ([]*gtsmodel.DomainBlock, error) {
+ blocks := []*gtsmodel.DomainBlock{}
+
+ if err := d.conn.
+ NewSelect().
+ Model(&blocks).
+ Scan(ctx); err != nil {
+ return nil, d.conn.ProcessError(err)
+ }
+
+ return blocks, nil
+}
+
+func (d *domainDB) GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, db.Error) {
+ var block gtsmodel.DomainBlock
+
+ q := d.conn.
+ NewSelect().
+ Model(&block).
+ Where("? = ?", bun.Ident("domain_block.id"), id)
+ if err := q.Scan(ctx); err != nil {
+ return nil, d.conn.ProcessError(err)
+ }
+
+ return &block, nil
+}
+
func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error {
// Normalize the domain as punycode
domain, err := util.Punify(domain)
diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go
index 95f6ad5b8..60d77600a 100644
--- a/internal/db/bundb/instance.go
+++ b/internal/db/bundb/instance.go
@@ -19,15 +19,23 @@ package bundb
import (
"context"
+ "time"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/id"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
type instanceDB struct {
- conn *DBConn
+ conn *DBConn
+ state *state.State
}
func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {
@@ -99,62 +107,236 @@ func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (i
}
func (i *instanceDB) GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, db.Error) {
- instance := &gtsmodel.Instance{}
+ // Normalize the domain as punycode
+ var err error
+ domain, err = util.Punify(domain)
+ if err != nil {
+ return nil, gtserror.Newf("error punifying domain %s: %w", domain, err)
+ }
- if err := i.conn.
- NewSelect().
- Model(instance).
- Where("? = ?", bun.Ident("instance.domain"), domain).
- Scan(ctx); err != nil {
- return nil, i.conn.ProcessError(err)
+ return i.getInstance(
+ ctx,
+ "Domain",
+ func(instance *gtsmodel.Instance) error {
+ return i.conn.NewSelect().
+ Model(instance).
+ Where("? = ?", bun.Ident("instance.domain"), domain).
+ Scan(ctx)
+ },
+ domain,
+ )
+}
+
+func (i *instanceDB) GetInstanceByID(ctx context.Context, id string) (*gtsmodel.Instance, error) {
+ return i.getInstance(
+ ctx,
+ "ID",
+ func(instance *gtsmodel.Instance) error {
+ return i.conn.NewSelect().
+ Model(instance).
+ Where("? = ?", bun.Ident("instance.id"), id).
+ Scan(ctx)
+ },
+ id,
+ )
+}
+
+func (i *instanceDB) getInstance(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Instance) error, keyParts ...any) (*gtsmodel.Instance, db.Error) {
+ // Fetch instance from database cache with loader callback
+ instance, err := i.state.Caches.GTS.Instance().Load(lookup, func() (*gtsmodel.Instance, error) {
+ var instance gtsmodel.Instance
+
+ // Not cached! Perform database query.
+ if err := dbQuery(&instance); err != nil {
+ return nil, i.conn.ProcessError(err)
+ }
+
+ return &instance, nil
+ }, keyParts...)
+ if err != nil {
+ return nil, err
+ }
+
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return instance, nil
+ }
+
+ // Further populate the instance fields where applicable.
+ if err := i.populateInstance(ctx, instance); err != nil {
+ return nil, err
}
return instance, nil
}
+func (i *instanceDB) populateInstance(ctx context.Context, instance *gtsmodel.Instance) error {
+ var (
+ err error
+ errs = make(gtserror.MultiError, 0, 2)
+ )
+
+ if instance.DomainBlockID != "" && instance.DomainBlock == nil {
+ // Instance domain block is not set, fetch from database.
+ instance.DomainBlock, err = i.state.DB.GetDomainBlock(
+ gtscontext.SetBarebones(ctx),
+ instance.Domain,
+ )
+ if err != nil {
+ errs.Append(gtserror.Newf("error populating instance domain block: %w", err))
+ }
+ }
+
+ if instance.ContactAccountID != "" && instance.ContactAccount == nil {
+ // Instance domain block is not set, fetch from database.
+ instance.ContactAccount, err = i.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ instance.ContactAccountID,
+ )
+ if err != nil {
+ errs.Append(gtserror.Newf("error populating instance contact account: %w", err))
+ }
+ }
+
+ return errs.Combine()
+}
+
+func (i *instanceDB) PutInstance(ctx context.Context, instance *gtsmodel.Instance) error {
+ // Normalize the domain as punycode
+ var err error
+ instance.Domain, err = util.Punify(instance.Domain)
+ if err != nil {
+ return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err)
+ }
+
+ return i.state.Caches.GTS.Instance().Store(instance, func() error {
+ _, err := i.conn.NewInsert().Model(instance).Exec(ctx)
+ return i.conn.ProcessError(err)
+ })
+}
+
+func (i *instanceDB) UpdateInstance(ctx context.Context, instance *gtsmodel.Instance, columns ...string) error {
+ // Normalize the domain as punycode
+ var err error
+ instance.Domain, err = util.Punify(instance.Domain)
+ if err != nil {
+ return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err)
+ }
+
+ // Update the instance's last-updated
+ instance.UpdatedAt = time.Now()
+ if len(columns) != 0 {
+ columns = append(columns, "updated_at")
+ }
+
+ return i.state.Caches.GTS.Instance().Store(instance, func() error {
+ _, err := i.conn.
+ NewUpdate().
+ Model(instance).
+ Where("? = ?", bun.Ident("instance.id"), instance.ID).
+ Column(columns...).
+ Exec(ctx)
+ return i.conn.ProcessError(err)
+ })
+}
+
func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool) ([]*gtsmodel.Instance, db.Error) {
- instances := []*gtsmodel.Instance{}
+ instanceIDs := []string{}
q := i.conn.
NewSelect().
- Model(&instances).
+ TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")).
+ // Select just the IDs of each instance.
+ Column("instance.id").
+ // Exclude our own instance.
Where("? != ?", bun.Ident("instance.domain"), config.GetHost())
if !includeSuspended {
q = q.Where("? IS NULL", bun.Ident("instance.suspended_at"))
}
- if err := q.Scan(ctx); err != nil {
+ if err := q.Scan(ctx, &instanceIDs); err != nil {
return nil, i.conn.ProcessError(err)
}
+ if len(instanceIDs) == 0 {
+ return make([]*gtsmodel.Instance, 0), nil
+ }
+
+ instances := make([]*gtsmodel.Instance, 0, len(instanceIDs))
+
+ for _, id := range instanceIDs {
+ // Select each instance by its ID.
+ instance, err := i.GetInstanceByID(ctx, id)
+ if err != nil {
+ log.Errorf(ctx, "error getting instance %q: %v", id, err)
+ continue
+ }
+
+ // Append to return slice.
+ instances = append(instances, instance)
+ }
+
return instances, nil
}
func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
- accounts := []*gtsmodel.Account{}
+ // Ensure reasonable
+ if limit < 0 {
+ limit = 0
+ }
+
+ // Normalize the domain as punycode.
+ var err error
+ domain, err = util.Punify(domain)
+ if err != nil {
+ return nil, gtserror.Newf("error punifying domain %s: %w", domain, err)
+ }
+
+ // Make educated guess for slice size
+ accountIDs := make([]string, 0, limit)
- q := i.conn.NewSelect().
- Model(&accounts).
+ q := i.conn.
+ NewSelect().
+ TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
+ // Select just the account ID.
+ Column("account.id").
+ // Select accounts belonging to given domain.
Where("? = ?", bun.Ident("account.domain"), domain).
Order("account.id DESC")
- if maxID != "" {
- q = q.Where("? < ?", bun.Ident("account.id"), maxID)
+ if maxID == "" {
+ maxID = id.Highest
}
+ q = q.Where("? < ?", bun.Ident("account.id"), maxID)
if limit > 0 {
q = q.Limit(limit)
}
- if err := q.Scan(ctx); err != nil {
+ if err := q.Scan(ctx, &accountIDs); err != nil {
return nil, i.conn.ProcessError(err)
}
- if len(accounts) == 0 {
+ // Catch case of no accounts early.
+ count := len(accountIDs)
+ if count == 0 {
return nil, db.ErrNoEntries
}
+ // Select each account by its ID.
+ accounts := make([]*gtsmodel.Account, 0, count)
+ for _, id := range accountIDs {
+ account, err := i.state.DB.GetAccountByID(ctx, id)
+ if err != nil {
+ log.Errorf(ctx, "error getting account %q: %v", id, err)
+ continue
+ }
+
+ // Append to return slice.
+ accounts = append(accounts, account)
+ }
+
return accounts, nil
}
diff --git a/internal/db/domain.go b/internal/db/domain.go
index 8918d6fe8..d859752af 100644
--- a/internal/db/domain.go
+++ b/internal/db/domain.go
@@ -26,13 +26,19 @@ import (
// Domain contains DB functions related to domains and domain blocks.
type Domain interface {
- // CreateDomainBlock ...
+ // CreateDomainBlock puts the given instance-level domain block into the database.
CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) Error
- // GetDomainBlock ...
+ // GetDomainBlock returns one instance-level domain block with the given domain, if it exists.
GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, Error)
- // DeleteDomainBlock ...
+ // GetDomainBlockByID returns one instance-level domain block with the given id, if it exists.
+ GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, Error)
+
+ // GetDomainBlocks returns all instance-level domain blocks currently enforced by this instance.
+ GetDomainBlocks(ctx context.Context) ([]*gtsmodel.DomainBlock, error)
+
+ // DeleteDomainBlock deletes an instance-level domain block with the given domain, if it exists.
DeleteDomainBlock(ctx context.Context, domain string) Error
// IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`).
diff --git a/internal/db/instance.go b/internal/db/instance.go
index 3166a0a18..ab40c7a82 100644
--- a/internal/db/instance.go
+++ b/internal/db/instance.go
@@ -37,6 +37,15 @@ type Instance interface {
// GetInstance returns the instance entry for the given domain, if it exists.
GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, Error)
+ // GetInstanceByID returns the instance entry corresponding to the given id, if it exists.
+ GetInstanceByID(ctx context.Context, id string) (*gtsmodel.Instance, error)
+
+ // PutInstance inserts the given instance into the database.
+ PutInstance(ctx context.Context, instance *gtsmodel.Instance) error
+
+ // UpdateInstance updates the given instance entry.
+ UpdateInstance(ctx context.Context, instance *gtsmodel.Instance, columns ...string) error
+
// GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID.
GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, Error)