diff options
Diffstat (limited to 'internal/db/pg/instance.go')
-rw-r--r-- | internal/db/pg/instance.go | 39 |
1 files changed, 25 insertions, 14 deletions
diff --git a/internal/db/pg/instance.go b/internal/db/pg/instance.go index c551b2a49..968832ca5 100644 --- a/internal/db/pg/instance.go +++ b/internal/db/pg/instance.go @@ -19,15 +19,26 @@ package pg import ( + "context" + "github.com/go-pg/pg/v10" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) { - q := ps.conn.Model(&[]*gtsmodel.Account{}) +type instanceDB struct { + config *config.Config + conn *pg.DB + log *logrus.Logger + cancel context.CancelFunc +} + +func (i *instanceDB) CountInstanceUsers(domain string) (int, db.Error) { + q := i.conn.Model(&[]*gtsmodel.Account{}) - if domain == ps.config.Host { + if domain == i.config.Host { // if the domain is *this* domain, just count where the domain field is null q = q.Where("? IS NULL", pg.Ident("domain")) } else { @@ -40,10 +51,10 @@ func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) { return q.Count() } -func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error) { - q := ps.conn.Model(&[]*gtsmodel.Status{}) +func (i *instanceDB) CountInstanceStatuses(domain string) (int, db.Error) { + q := i.conn.Model(&[]*gtsmodel.Status{}) - if domain == ps.config.Host { + if domain == i.config.Host { // if the domain is *this* domain, just count where local is true q = q.Where("local = ?", true) } else { @@ -55,10 +66,10 @@ func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error) return q.Count() } -func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error) { - q := ps.conn.Model(&[]*gtsmodel.Instance{}) +func (i *instanceDB) CountInstanceDomains(domain string) (int, db.Error) { + q := i.conn.Model(&[]*gtsmodel.Instance{}) - if domain == ps.config.Host { + if domain == i.config.Host { // if the domain is *this* domain, just count other instances it knows about // exclude domains that are blocked q = q.Where("domain != ?", domain).Where("? IS NULL", pg.Ident("suspended_at")) @@ -70,12 +81,12 @@ func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error) return q.Count() } -func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, error) { - ps.log.Debug("GetAccountsForInstance") +func (i *instanceDB) GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) { + i.log.Debug("GetAccountsForInstance") accounts := []*gtsmodel.Account{} - q := ps.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC") + q := i.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC") if maxID != "" { q = q.Where("id < ?", maxID) @@ -88,13 +99,13 @@ func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, l err := q.Select() if err != nil { if err == pg.ErrNoRows { - return nil, db.ErrNoEntries{} + return nil, db.ErrNoEntries } return nil, err } if len(accounts) == 0 { - return nil, db.ErrNoEntries{} + return nil, db.ErrNoEntries } return accounts, nil |