summaryrefslogtreecommitdiff
path: root/internal/db/pg/instance.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/pg/instance.go')
-rw-r--r--internal/db/pg/instance.go39
1 files changed, 25 insertions, 14 deletions
diff --git a/internal/db/pg/instance.go b/internal/db/pg/instance.go
index c551b2a49..968832ca5 100644
--- a/internal/db/pg/instance.go
+++ b/internal/db/pg/instance.go
@@ -19,15 +19,26 @@
package pg
import (
+ "context"
+
"github.com/go-pg/pg/v10"
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) {
- q := ps.conn.Model(&[]*gtsmodel.Account{})
+type instanceDB struct {
+ config *config.Config
+ conn *pg.DB
+ log *logrus.Logger
+ cancel context.CancelFunc
+}
+
+func (i *instanceDB) CountInstanceUsers(domain string) (int, db.Error) {
+ q := i.conn.Model(&[]*gtsmodel.Account{})
- if domain == ps.config.Host {
+ if domain == i.config.Host {
// if the domain is *this* domain, just count where the domain field is null
q = q.Where("? IS NULL", pg.Ident("domain"))
} else {
@@ -40,10 +51,10 @@ func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) {
return q.Count()
}
-func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error) {
- q := ps.conn.Model(&[]*gtsmodel.Status{})
+func (i *instanceDB) CountInstanceStatuses(domain string) (int, db.Error) {
+ q := i.conn.Model(&[]*gtsmodel.Status{})
- if domain == ps.config.Host {
+ if domain == i.config.Host {
// if the domain is *this* domain, just count where local is true
q = q.Where("local = ?", true)
} else {
@@ -55,10 +66,10 @@ func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error)
return q.Count()
}
-func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error) {
- q := ps.conn.Model(&[]*gtsmodel.Instance{})
+func (i *instanceDB) CountInstanceDomains(domain string) (int, db.Error) {
+ q := i.conn.Model(&[]*gtsmodel.Instance{})
- if domain == ps.config.Host {
+ if domain == i.config.Host {
// if the domain is *this* domain, just count other instances it knows about
// exclude domains that are blocked
q = q.Where("domain != ?", domain).Where("? IS NULL", pg.Ident("suspended_at"))
@@ -70,12 +81,12 @@ func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error)
return q.Count()
}
-func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, error) {
- ps.log.Debug("GetAccountsForInstance")
+func (i *instanceDB) GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
+ i.log.Debug("GetAccountsForInstance")
accounts := []*gtsmodel.Account{}
- q := ps.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC")
+ q := i.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC")
if maxID != "" {
q = q.Where("id < ?", maxID)
@@ -88,13 +99,13 @@ func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, l
err := q.Select()
if err != nil {
if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries{}
+ return nil, db.ErrNoEntries
}
return nil, err
}
if len(accounts) == 0 {
- return nil, db.ErrNoEntries{}
+ return nil, db.ErrNoEntries
}
return accounts, nil