diff options
Diffstat (limited to 'internal/db')
| -rw-r--r-- | internal/db/bundb/bundb.go | 3 | ||||
| -rw-r--r-- | internal/db/bundb/domain.go | 29 | ||||
| -rw-r--r-- | internal/db/bundb/instance.go | 218 | ||||
| -rw-r--r-- | internal/db/domain.go | 12 | ||||
| -rw-r--r-- | internal/db/instance.go | 9 | 
5 files changed, 248 insertions, 23 deletions
| diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 9d616954a..ee28800b5 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -179,7 +179,8 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {  			state: state,  		},  		Instance: &instanceDB{ -			conn: conn, +			conn:  conn, +			state: state,  		},  		List: &listDB{  			conn:  conn, diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index 5c92645de..2e8ce2a6b 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -42,7 +42,7 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.Domain  		return err  	} -	// Attempt to store domain in DB +	// Attempt to store domain block in DB  	if _, err := d.conn.NewInsert().  		Model(block).  		Exec(ctx); err != nil { @@ -82,6 +82,33 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel  	return &block, nil  } +func (d *domainDB) GetDomainBlocks(ctx context.Context) ([]*gtsmodel.DomainBlock, error) { +	blocks := []*gtsmodel.DomainBlock{} + +	if err := d.conn. +		NewSelect(). +		Model(&blocks). +		Scan(ctx); err != nil { +		return nil, d.conn.ProcessError(err) +	} + +	return blocks, nil +} + +func (d *domainDB) GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, db.Error) { +	var block gtsmodel.DomainBlock + +	q := d.conn. +		NewSelect(). +		Model(&block). +		Where("? = ?", bun.Ident("domain_block.id"), id) +	if err := q.Scan(ctx); err != nil { +		return nil, d.conn.ProcessError(err) +	} + +	return &block, nil +} +  func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error {  	// Normalize the domain as punycode  	domain, err := util.Punify(domain) diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go index 95f6ad5b8..60d77600a 100644 --- a/internal/db/bundb/instance.go +++ b/internal/db/bundb/instance.go @@ -19,15 +19,23 @@ package bundb  import (  	"context" +	"time"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" +	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/id" +	"github.com/superseriousbusiness/gotosocial/internal/log" +	"github.com/superseriousbusiness/gotosocial/internal/state" +	"github.com/superseriousbusiness/gotosocial/internal/util"  	"github.com/uptrace/bun"  )  type instanceDB struct { -	conn *DBConn +	conn  *DBConn +	state *state.State  }  func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) { @@ -99,62 +107,236 @@ func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (i  }  func (i *instanceDB) GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, db.Error) { -	instance := >smodel.Instance{} +	// Normalize the domain as punycode +	var err error +	domain, err = util.Punify(domain) +	if err != nil { +		return nil, gtserror.Newf("error punifying domain %s: %w", domain, err) +	} -	if err := i.conn. -		NewSelect(). -		Model(instance). -		Where("? = ?", bun.Ident("instance.domain"), domain). -		Scan(ctx); err != nil { -		return nil, i.conn.ProcessError(err) +	return i.getInstance( +		ctx, +		"Domain", +		func(instance *gtsmodel.Instance) error { +			return i.conn.NewSelect(). +				Model(instance). +				Where("? = ?", bun.Ident("instance.domain"), domain). +				Scan(ctx) +		}, +		domain, +	) +} + +func (i *instanceDB) GetInstanceByID(ctx context.Context, id string) (*gtsmodel.Instance, error) { +	return i.getInstance( +		ctx, +		"ID", +		func(instance *gtsmodel.Instance) error { +			return i.conn.NewSelect(). +				Model(instance). +				Where("? = ?", bun.Ident("instance.id"), id). +				Scan(ctx) +		}, +		id, +	) +} + +func (i *instanceDB) getInstance(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Instance) error, keyParts ...any) (*gtsmodel.Instance, db.Error) { +	// Fetch instance from database cache with loader callback +	instance, err := i.state.Caches.GTS.Instance().Load(lookup, func() (*gtsmodel.Instance, error) { +		var instance gtsmodel.Instance + +		// Not cached! Perform database query. +		if err := dbQuery(&instance); err != nil { +			return nil, i.conn.ProcessError(err) +		} + +		return &instance, nil +	}, keyParts...) +	if err != nil { +		return nil, err +	} + +	if gtscontext.Barebones(ctx) { +		// no need to fully populate. +		return instance, nil +	} + +	// Further populate the instance fields where applicable. +	if err := i.populateInstance(ctx, instance); err != nil { +		return nil, err  	}  	return instance, nil  } +func (i *instanceDB) populateInstance(ctx context.Context, instance *gtsmodel.Instance) error { +	var ( +		err  error +		errs = make(gtserror.MultiError, 0, 2) +	) + +	if instance.DomainBlockID != "" && instance.DomainBlock == nil { +		// Instance domain block is not set, fetch from database. +		instance.DomainBlock, err = i.state.DB.GetDomainBlock( +			gtscontext.SetBarebones(ctx), +			instance.Domain, +		) +		if err != nil { +			errs.Append(gtserror.Newf("error populating instance domain block: %w", err)) +		} +	} + +	if instance.ContactAccountID != "" && instance.ContactAccount == nil { +		// Instance domain block is not set, fetch from database. +		instance.ContactAccount, err = i.state.DB.GetAccountByID( +			gtscontext.SetBarebones(ctx), +			instance.ContactAccountID, +		) +		if err != nil { +			errs.Append(gtserror.Newf("error populating instance contact account: %w", err)) +		} +	} + +	return errs.Combine() +} + +func (i *instanceDB) PutInstance(ctx context.Context, instance *gtsmodel.Instance) error { +	// Normalize the domain as punycode +	var err error +	instance.Domain, err = util.Punify(instance.Domain) +	if err != nil { +		return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err) +	} + +	return i.state.Caches.GTS.Instance().Store(instance, func() error { +		_, err := i.conn.NewInsert().Model(instance).Exec(ctx) +		return i.conn.ProcessError(err) +	}) +} + +func (i *instanceDB) UpdateInstance(ctx context.Context, instance *gtsmodel.Instance, columns ...string) error { +	// Normalize the domain as punycode +	var err error +	instance.Domain, err = util.Punify(instance.Domain) +	if err != nil { +		return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err) +	} + +	// Update the instance's last-updated +	instance.UpdatedAt = time.Now() +	if len(columns) != 0 { +		columns = append(columns, "updated_at") +	} + +	return i.state.Caches.GTS.Instance().Store(instance, func() error { +		_, err := i.conn. +			NewUpdate(). +			Model(instance). +			Where("? = ?", bun.Ident("instance.id"), instance.ID). +			Column(columns...). +			Exec(ctx) +		return i.conn.ProcessError(err) +	}) +} +  func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool) ([]*gtsmodel.Instance, db.Error) { -	instances := []*gtsmodel.Instance{} +	instanceIDs := []string{}  	q := i.conn.  		NewSelect(). -		Model(&instances). +		TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")). +		// Select just the IDs of each instance. +		Column("instance.id"). +		// Exclude our own instance.  		Where("? != ?", bun.Ident("instance.domain"), config.GetHost())  	if !includeSuspended {  		q = q.Where("? IS NULL", bun.Ident("instance.suspended_at"))  	} -	if err := q.Scan(ctx); err != nil { +	if err := q.Scan(ctx, &instanceIDs); err != nil {  		return nil, i.conn.ProcessError(err)  	} +	if len(instanceIDs) == 0 { +		return make([]*gtsmodel.Instance, 0), nil +	} + +	instances := make([]*gtsmodel.Instance, 0, len(instanceIDs)) + +	for _, id := range instanceIDs { +		// Select each instance by its ID. +		instance, err := i.GetInstanceByID(ctx, id) +		if err != nil { +			log.Errorf(ctx, "error getting instance %q: %v", id, err) +			continue +		} + +		// Append to return slice. +		instances = append(instances, instance) +	} +  	return instances, nil  }  func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) { -	accounts := []*gtsmodel.Account{} +	// Ensure reasonable +	if limit < 0 { +		limit = 0 +	} + +	// Normalize the domain as punycode. +	var err error +	domain, err = util.Punify(domain) +	if err != nil { +		return nil, gtserror.Newf("error punifying domain %s: %w", domain, err) +	} + +	// Make educated guess for slice size +	accountIDs := make([]string, 0, limit) -	q := i.conn.NewSelect(). -		Model(&accounts). +	q := i.conn. +		NewSelect(). +		TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). +		// Select just the account ID. +		Column("account.id"). +		// Select accounts belonging to given domain.  		Where("? = ?", bun.Ident("account.domain"), domain).  		Order("account.id DESC") -	if maxID != "" { -		q = q.Where("? < ?", bun.Ident("account.id"), maxID) +	if maxID == "" { +		maxID = id.Highest  	} +	q = q.Where("? < ?", bun.Ident("account.id"), maxID)  	if limit > 0 {  		q = q.Limit(limit)  	} -	if err := q.Scan(ctx); err != nil { +	if err := q.Scan(ctx, &accountIDs); err != nil {  		return nil, i.conn.ProcessError(err)  	} -	if len(accounts) == 0 { +	// Catch case of no accounts early. +	count := len(accountIDs) +	if count == 0 {  		return nil, db.ErrNoEntries  	} +	// Select each account by its ID. +	accounts := make([]*gtsmodel.Account, 0, count) +	for _, id := range accountIDs { +		account, err := i.state.DB.GetAccountByID(ctx, id) +		if err != nil { +			log.Errorf(ctx, "error getting account %q: %v", id, err) +			continue +		} + +		// Append to return slice. +		accounts = append(accounts, account) +	} +  	return accounts, nil  } diff --git a/internal/db/domain.go b/internal/db/domain.go index 8918d6fe8..d859752af 100644 --- a/internal/db/domain.go +++ b/internal/db/domain.go @@ -26,13 +26,19 @@ import (  // Domain contains DB functions related to domains and domain blocks.  type Domain interface { -	// CreateDomainBlock ... +	// CreateDomainBlock puts the given instance-level domain block into the database.  	CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) Error -	// GetDomainBlock ... +	// GetDomainBlock returns one instance-level domain block with the given domain, if it exists.  	GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, Error) -	// DeleteDomainBlock ... +	// GetDomainBlockByID returns one instance-level domain block with the given id, if it exists. +	GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, Error) + +	// GetDomainBlocks returns all instance-level domain blocks currently enforced by this instance. +	GetDomainBlocks(ctx context.Context) ([]*gtsmodel.DomainBlock, error) + +	// DeleteDomainBlock deletes an instance-level domain block with the given domain, if it exists.  	DeleteDomainBlock(ctx context.Context, domain string) Error  	// IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`). diff --git a/internal/db/instance.go b/internal/db/instance.go index 3166a0a18..ab40c7a82 100644 --- a/internal/db/instance.go +++ b/internal/db/instance.go @@ -37,6 +37,15 @@ type Instance interface {  	// GetInstance returns the instance entry for the given domain, if it exists.  	GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, Error) +	// GetInstanceByID returns the instance entry corresponding to the given id, if it exists. +	GetInstanceByID(ctx context.Context, id string) (*gtsmodel.Instance, error) + +	// PutInstance inserts the given instance into the database. +	PutInstance(ctx context.Context, instance *gtsmodel.Instance) error + +	// UpdateInstance updates the given instance entry. +	UpdateInstance(ctx context.Context, instance *gtsmodel.Instance, columns ...string) error +  	// GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID.  	GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, Error) | 
