From d389e7b150df6ecd215c7b661b294ea153ad0103 Mon Sep 17 00:00:00 2001 From: Tobi Smethurst <31960611+tsmethurst@users.noreply.github.com> Date: Mon, 5 Jul 2021 13:23:03 +0200 Subject: Domain block (#76) * start work on admin domain blocking * move stuff around + further work on domain blocks * move + restructure processor * prep work for deleting account * tidy * go fmt * formatting * domain blocking more work * check domain blocks way earlier on * progress on delete account * delete more stuff when an account is gone * and more... * domain blocky block block * get individual domain block, delete a block --- internal/db/db.go | 13 +++---- internal/db/pg/instance.go | 83 +++++++++++++++++++++++++++++++++++++++++ internal/db/pg/instancestats.go | 52 -------------------------- internal/db/pg/pg.go | 33 ++++++++++------ 4 files changed, 111 insertions(+), 70 deletions(-) create mode 100644 internal/db/pg/instance.go delete mode 100644 internal/db/pg/instancestats.go (limited to 'internal/db') diff --git a/internal/db/db.go b/internal/db/db.go index 204f04c71..1ec02d22c 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -65,11 +65,6 @@ type DB interface { // In case of no entries, a 'no entries' error will be returned GetWhere(where []Where, i interface{}) error - // // GetWhereMany gets one entry where key = value for *ALL* parameters passed as "where". - // // That is, if you pass 2 'where' entries, with 1 being Key username and Value test, and the second - // // being Key domain and Value example.org, only entries will be returned where BOTH conditions are true. - // GetWhereMany(i interface{}, where ...model.Where) error - // GetAll will try to get all entries of type i. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // In case of no entries, a 'no entries' error will be returned @@ -155,11 +150,11 @@ type DB interface { // CountStatusesByAccountID is a shortcut for the common action of counting statuses produced by accountID. CountStatusesByAccountID(accountID string) (int, error) - // GetStatusesByTimeDescending is a shortcut for getting the most recent statuses. accountID is optional, if not provided + // GetStatusesForAccount is a shortcut for getting the most recent statuses. accountID is optional, if not provided // then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can // be very memory intensive so you probably shouldn't do this! // In case of no entries, a 'no entries' error will be returned - GetStatusesByTimeDescending(accountID string, statuses *[]gtsmodel.Status, limit int, excludeReplies bool, maxID string, pinned bool, mediaOnly bool) error + GetStatusesForAccount(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, error) // GetLastStatusForAccountID simply gets the most recent status by the given account. // The given slice 'status' pointer will be set to the result of the query, whatever it is. @@ -261,6 +256,10 @@ type DB interface { // GetDomainCountForInstance returns the number of known instances known that the given domain federates with. GetDomainCountForInstance(domain string) (int, error) + + // GetAccountsForInstance returns a slice of accounts from the given instance, arranged by ID. + GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, error) + /* USEFUL CONVERSION FUNCTIONS */ diff --git a/internal/db/pg/instance.go b/internal/db/pg/instance.go new file mode 100644 index 000000000..2de0c5366 --- /dev/null +++ b/internal/db/pg/instance.go @@ -0,0 +1,83 @@ +package pg + +import ( + "github.com/go-pg/pg/v10" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) { + q := ps.conn.Model(&[]*gtsmodel.Account{}) + + if domain == ps.config.Host { + // if the domain is *this* domain, just count where the domain field is null + q = q.Where("? IS NULL", pg.Ident("domain")) + } else { + q = q.Where("domain = ?", domain) + } + + // don't count the instance account or suspended users + q = q.Where("username != ?", domain).Where("? IS NULL", pg.Ident("suspended_at")) + + return q.Count() +} + +func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error) { + q := ps.conn.Model(&[]*gtsmodel.Status{}) + + if domain == ps.config.Host { + // if the domain is *this* domain, just count where local is true + q = q.Where("local = ?", true) + } else { + // join on the domain of the account + q = q.Join("JOIN accounts AS account ON account.id = status.account_id"). + Where("account.domain = ?", domain) + } + + return q.Count() +} + +func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error) { + q := ps.conn.Model(&[]*gtsmodel.Instance{}) + + if domain == ps.config.Host { + // if the domain is *this* domain, just count other instances it knows about + // TODO: exclude domains that are blocked or silenced + q = q.Where("domain != ?", domain) + } else { + // TODO: implement federated domain counting properly for remote domains + return 0, nil + } + + return q.Count() +} + +func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, error) { + ps.log.Debug("GetAccountsForInstance") + + accounts := []*gtsmodel.Account{} + + q := ps.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC") + + if maxID != "" { + q = q.Where("id < ?", maxID) + } + + if limit > 0 { + q = q.Limit(limit) + } + + err := q.Select() + if err != nil { + if err == pg.ErrNoRows { + return nil, db.ErrNoEntries{} + } + return nil, err + } + + if len(accounts) == 0 { + return nil, db.ErrNoEntries{} + } + + return accounts, nil +} diff --git a/internal/db/pg/instancestats.go b/internal/db/pg/instancestats.go deleted file mode 100644 index b57591d7b..000000000 --- a/internal/db/pg/instancestats.go +++ /dev/null @@ -1,52 +0,0 @@ -package pg - -import ( - "github.com/go-pg/pg/v10" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) { - q := ps.conn.Model(&[]*gtsmodel.Account{}) - - if domain == ps.config.Host { - // if the domain is *this* domain, just count where the domain field is null - q = q.Where("? IS NULL", pg.Ident("domain")) - } else { - q = q.Where("domain = ?", domain) - } - - // don't count the instance account or suspended users - q = q.Where("username != ?", domain).Where("? IS NULL", pg.Ident("suspended_at")) - - return q.Count() -} - -func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error) { - q := ps.conn.Model(&[]*gtsmodel.Status{}) - - if domain == ps.config.Host { - // if the domain is *this* domain, just count where local is true - q = q.Where("local = ?", true) - } else { - // join on the domain of the account - q = q.Join("JOIN accounts AS account ON account.id = status.account_id"). - Where("account.domain = ?", domain) - } - - return q.Count() -} - -func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error) { - q := ps.conn.Model(&[]*gtsmodel.Instance{}) - - if domain == ps.config.Host { - // if the domain is *this* domain, just count other instances it knows about - // TODO: exclude domains that are blocked or silenced - q = q.Where("domain != ?", domain) - } else { - // TODO: implement federated domain counting properly for remote domains - return 0, nil - } - - return q.Count() -} diff --git a/internal/db/pg/pg.go b/internal/db/pg/pg.go index 2758d3c3d..1050f141e 100644 --- a/internal/db/pg/pg.go +++ b/internal/db/pg/pg.go @@ -511,39 +511,50 @@ func (ps *postgresService) CountStatusesByAccountID(accountID string) (int, erro return count, nil } -func (ps *postgresService) GetStatusesByTimeDescending(accountID string, statuses *[]gtsmodel.Status, limit int, excludeReplies bool, maxID string, pinned bool, mediaOnly bool) error { - q := ps.conn.Model(statuses).Order("created_at DESC") +func (ps *postgresService) GetStatusesForAccount(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, error) { + ps.log.Debugf("getting statuses for account %s", accountID) + statuses := []*gtsmodel.Status{} + + q := ps.conn.Model(&statuses).Order("id DESC") if accountID != "" { q = q.Where("account_id = ?", accountID) } + if limit != 0 { q = q.Limit(limit) } + if excludeReplies { q = q.Where("? IS NULL", pg.Ident("in_reply_to_id")) } - if pinned { + + if pinnedOnly { q = q.Where("pinned = ?", true) } + if mediaOnly { q = q.WhereGroup(func(q *pg.Query) (*pg.Query, error) { return q.Where("? IS NOT NULL", pg.Ident("attachments")).Where("attachments != '{}'"), nil }) } + if maxID != "" { - s := >smodel.Status{} - if err := ps.conn.Model(s).Where("id = ?", maxID).Select(); err != nil { - return err - } - q = q.Where("status.created_at < ?", s.CreatedAt) + q = q.Where("id < ?", maxID) } + if err := q.Select(); err != nil { if err == pg.ErrNoRows { - return db.ErrNoEntries{} + return nil, db.ErrNoEntries{} } - return err + return nil, err } - return nil + + if len(statuses) == 0 { + return nil, db.ErrNoEntries{} + } + + ps.log.Debugf("returning statuses for account %s", accountID) + return statuses, nil } func (ps *postgresService) GetLastStatusForAccountID(accountID string, status *gtsmodel.Status) error { -- cgit v1.3