diff options
author | 2024-04-16 13:10:13 +0200 | |
---|---|---|
committer | 2024-04-16 13:10:13 +0200 | |
commit | 3cceed11b28b5f42a653d85ed779d652fd8c26ad (patch) | |
tree | 0a7f0994e477609ca705a45f382dfb62056b196e /internal/db | |
parent | [performance] cached oauth database types (#2838) (diff) | |
download | gotosocial-3cceed11b28b5f42a653d85ed779d652fd8c26ad.tar.xz |
[feature/performance] Store account stats in separate table (#2831)
* [feature/performance] Store account stats in separate table, get stats from remote
* test account stats
* add some missing increment / decrement calls
* change stats function signatures
* rejig logging a bit
* use lock when updating stats
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/account.go | 34 | ||||
-rw-r--r-- | internal/db/bundb/account.go | 267 | ||||
-rw-r--r-- | internal/db/bundb/account_test.go | 79 | ||||
-rw-r--r-- | internal/db/bundb/migrations/20240414122348_account_stats_model.go | 52 | ||||
-rw-r--r-- | internal/db/bundb/relationship.go | 63 | ||||
-rw-r--r-- | internal/db/bundb/relationship_test.go | 28 | ||||
-rw-r--r-- | internal/db/bundb/upsert.go | 6 | ||||
-rw-r--r-- | internal/db/relationship.go | 40 |
8 files changed, 349 insertions, 220 deletions
diff --git a/internal/db/account.go b/internal/db/account.go index 7cdf7b57f..dec36d2ac 100644 --- a/internal/db/account.go +++ b/internal/db/account.go @@ -20,7 +20,6 @@ package db import ( "context" "net/netip" - "time" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/paging" @@ -100,12 +99,6 @@ type Account interface { // GetAccountsUsingEmoji fetches all account models using emoji with given ID stored in their 'emojis' column. GetAccountsUsingEmoji(ctx context.Context, emojiID string) ([]*gtsmodel.Account, error) - // GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID. - CountAccountStatuses(ctx context.Context, accountID string) (int, error) - - // CountAccountPinned returns the total number of pinned statuses owned by account with the given id. - CountAccountPinned(ctx context.Context, accountID string) (int, error) - // GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided // then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can // be very memory intensive so you probably shouldn't do this! @@ -128,13 +121,6 @@ type Account interface { // In the case of no statuses, this function will return db.ErrNoEntries. GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error) - // GetAccountLastPosted simply gets the timestamp of the most recent post by the account. - // - // If webOnly is true, then the time of the last non-reply, non-boost, public status of the account will be returned. - // - // The returned time will be zero if account has never posted anything. - GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, error) - // SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment. SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) error @@ -150,4 +136,24 @@ type Account interface { // Update local account settings. UpdateAccountSettings(ctx context.Context, settings *gtsmodel.AccountSettings, columns ...string) error + + // PopulateAccountStats gets (or creates and gets) account stats for + // the given account, and attaches them to the account model. + PopulateAccountStats(ctx context.Context, account *gtsmodel.Account) error + + // RegenerateAccountStats creates, upserts, and returns stats + // for the given account, and attaches them to the account model. + // + // Unlike GetAccountStats, it will always get the database stats fresh. + // This can be used to "refresh" stats. + // + // Because this involves database calls that can be expensive (on Postgres + // specifically), callers should prefer GetAccountStats in 99% of cases. + RegenerateAccountStats(ctx context.Context, account *gtsmodel.Account) error + + // Update account stats. + UpdateAccountStats(ctx context.Context, stats *gtsmodel.AccountStats, columns ...string) error + + // DeleteAccountStats deletes the accountStats entry for the given accountID. + DeleteAccountStats(ctx context.Context, accountID string) error } diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 45e67c10b..2b3c78aff 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -630,6 +630,13 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou } } + if account.Stats == nil { + // Get / Create stats for this account. + if err := a.state.DB.PopulateAccountStats(ctx, account); err != nil { + errs.Appendf("error populating account stats: %w", err) + } + } + return errs.Combine() } @@ -735,31 +742,6 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) error { }) } -func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, error) { - createdAt := time.Time{} - - q := a.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). - Column("status.created_at"). - Where("? = ?", bun.Ident("status.account_id"), accountID). - Order("status.id DESC"). - Limit(1) - - if webOnly { - q = q. - Where("? IS NULL", bun.Ident("status.in_reply_to_uri")). - Where("? IS NULL", bun.Ident("status.boost_of_id")). - Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic). - Where("? = ?", bun.Ident("status.federated"), true) - } - - if err := q.Scan(ctx, &createdAt); err != nil { - return time.Time{}, err - } - return createdAt, nil -} - func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) error { if *mediaAttachment.Avatar && *mediaAttachment.Header { return errors.New("one media attachment cannot be both header and avatar") @@ -845,59 +827,6 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g return *faves, nil } -func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, error) { - counts, err := a.getAccountStatusCounts(ctx, accountID) - return counts.Statuses, err -} - -func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (int, error) { - counts, err := a.getAccountStatusCounts(ctx, accountID) - return counts.Pinned, err -} - -func (a *accountDB) getAccountStatusCounts(ctx context.Context, accountID string) (struct { - Statuses int - Pinned int -}, error) { - // Check for an already cached copy of account status counts. - counts, ok := a.state.Caches.GTS.AccountCounts.Get(accountID) - if ok { - return counts, nil - } - - if err := a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { - var err error - - // Scan database for account statuses. - counts.Statuses, err = tx.NewSelect(). - Table("statuses"). - Where("? = ?", bun.Ident("account_id"), accountID). - Count(ctx) - if err != nil { - return err - } - - // Scan database for pinned statuses. - counts.Pinned, err = tx.NewSelect(). - Table("statuses"). - Where("? = ?", bun.Ident("account_id"), accountID). - Where("? IS NOT NULL", bun.Ident("pinned_at")). - Count(ctx) - if err != nil { - return err - } - - return nil - }); err != nil { - return counts, err - } - - // Store this account counts result in the cache. - a.state.Caches.GTS.AccountCounts.Set(accountID, counts) - - return counts, nil -} - func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, error) { // Ensure reasonable if limit < 0 { @@ -1147,3 +1076,185 @@ func (a *accountDB) UpdateAccountSettings( return nil }) } + +func (a *accountDB) PopulateAccountStats(ctx context.Context, account *gtsmodel.Account) error { + // Fetch stats from db cache with loader callback. + stats, err := a.state.Caches.GTS.AccountStats.LoadOne( + "AccountID", + func() (*gtsmodel.AccountStats, error) { + // Not cached! Perform database query. + var stats gtsmodel.AccountStats + if err := a.db. + NewSelect(). + Model(&stats). + Where("? = ?", bun.Ident("account_stats.account_id"), account.ID). + Scan(ctx); err != nil { + return nil, err + } + return &stats, nil + }, + account.ID, + ) + + if err != nil && !errors.Is(err, db.ErrNoEntries) { + // Real error. + return err + } + + if stats == nil { + // Don't have stats yet, generate them. + return a.RegenerateAccountStats(ctx, account) + } + + // We have a stats, attach + // it to the account. + account.Stats = stats + + // Check if this is a local + // stats by looking at the + // account they pertain to. + if account.IsRemote() { + // Account is remote. Updating + // stats for remote accounts is + // handled in the dereferencer. + // + // Nothing more to do! + return nil + } + + // Stats account is local, check + // if we need to regenerate. + const statsFreshness = 48 * time.Hour + expiry := stats.RegeneratedAt.Add(statsFreshness) + if time.Now().After(expiry) { + // Stats have expired, regenerate them. + return a.RegenerateAccountStats(ctx, account) + } + + // Stats are still fresh. + return nil +} + +func (a *accountDB) RegenerateAccountStats(ctx context.Context, account *gtsmodel.Account) error { + // Initialize a new stats struct. + stats := >smodel.AccountStats{ + AccountID: account.ID, + RegeneratedAt: time.Now(), + } + + // Count followers outside of transaction since + // it uses a cache + requires its own db calls. + followerIDs, err := a.state.DB.GetAccountFollowerIDs(ctx, account.ID, nil) + if err != nil { + return err + } + stats.FollowersCount = util.Ptr(len(followerIDs)) + + // Count following outside of transaction since + // it uses a cache + requires its own db calls. + followIDs, err := a.state.DB.GetAccountFollowIDs(ctx, account.ID, nil) + if err != nil { + return err + } + stats.FollowingCount = util.Ptr(len(followIDs)) + + // Count follow requests outside of transaction since + // it uses a cache + requires its own db calls. + followRequestIDs, err := a.state.DB.GetAccountFollowRequestIDs(ctx, account.ID, nil) + if err != nil { + return err + } + stats.FollowRequestsCount = util.Ptr(len(followRequestIDs)) + + // Populate remaining stats struct fields. + // This can be done inside a transaction. + if err := a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + var err error + + // Scan database for account statuses. + statusesCount, err := tx.NewSelect(). + Table("statuses"). + Where("? = ?", bun.Ident("account_id"), account.ID). + Count(ctx) + if err != nil { + return err + } + stats.StatusesCount = &statusesCount + + // Scan database for pinned statuses. + statusesPinnedCount, err := tx.NewSelect(). + Table("statuses"). + Where("? = ?", bun.Ident("account_id"), account.ID). + Where("? IS NOT NULL", bun.Ident("pinned_at")). + Count(ctx) + if err != nil { + return err + } + stats.StatusesPinnedCount = &statusesPinnedCount + + // Scan database for last status. + lastStatusAt := time.Time{} + err = tx. + NewSelect(). + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). + Column("status.created_at"). + Where("? = ?", bun.Ident("status.account_id"), account.ID). + Order("status.id DESC"). + Limit(1). + Scan(ctx, &lastStatusAt) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return err + } + stats.LastStatusAt = lastStatusAt + + return nil + }); err != nil { + return err + } + + // Upsert this stats in case a race + // meant someone else inserted it first. + if err := a.state.Caches.GTS.AccountStats.Store(stats, func() error { + if _, err := NewUpsert(a.db). + Model(stats). + Constraint("account_id"). + Exec(ctx); err != nil { + return err + } + return nil + }); err != nil { + return err + } + + account.Stats = stats + return nil +} + +func (a *accountDB) UpdateAccountStats(ctx context.Context, stats *gtsmodel.AccountStats, columns ...string) error { + return a.state.Caches.GTS.AccountStats.Store(stats, func() error { + if _, err := a.db. + NewUpdate(). + Model(stats). + Column(columns...). + Where("? = ?", bun.Ident("account_stats.account_id"), stats.AccountID). + Exec(ctx); err != nil { + return err + } + + return nil + }) +} + +func (a *accountDB) DeleteAccountStats(ctx context.Context, accountID string) error { + defer a.state.Caches.GTS.AccountStats.Invalidate("AccountID", accountID) + + if _, err := a.db. + NewDelete(). + Table("account_stats"). + Where("? = ?", bun.Ident("account_id"), accountID). + Exec(ctx); err != nil { + return err + } + + return nil +} diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go index dd96543b6..ea211e16f 100644 --- a/internal/db/bundb/account_test.go +++ b/internal/db/bundb/account_test.go @@ -220,6 +220,8 @@ func (suite *AccountTestSuite) TestGetAccountBy() { a2.Emojis = nil a1.Settings = nil a2.Settings = nil + a1.Stats = nil + a2.Stats = nil // Clear database-set fields. a1.CreatedAt = time.Time{} @@ -413,18 +415,6 @@ func (suite *AccountTestSuite) TestUpdateAccount() { suite.WithinDuration(time.Now(), noCache.UpdatedAt, 5*time.Second) } -func (suite *AccountTestSuite) TestGetAccountLastPosted() { - lastPosted, err := suite.db.GetAccountLastPosted(context.Background(), suite.testAccounts["local_account_1"].ID, false) - suite.NoError(err) - suite.EqualValues(1702200240, lastPosted.Unix()) -} - -func (suite *AccountTestSuite) TestGetAccountLastPostedWebOnly() { - lastPosted, err := suite.db.GetAccountLastPosted(context.Background(), suite.testAccounts["local_account_1"].ID, true) - suite.NoError(err) - suite.EqualValues(1702200240, lastPosted.Unix()) -} - func (suite *AccountTestSuite) TestInsertAccountWithDefaults() { key, err := rsa.GenerateKey(rand.Reader, 2048) suite.NoError(err) @@ -466,22 +456,6 @@ func (suite *AccountTestSuite) TestGetAccountPinnedStatusesNothingPinned() { suite.Empty(statuses) // This account has nothing pinned. } -func (suite *AccountTestSuite) TestCountAccountPinnedSomeResults() { - testAccount := suite.testAccounts["admin_account"] - - pinned, err := suite.db.CountAccountPinned(context.Background(), testAccount.ID) - suite.NoError(err) - suite.Equal(pinned, 2) // This account has 2 statuses pinned. -} - -func (suite *AccountTestSuite) TestCountAccountPinnedNothingPinned() { - testAccount := suite.testAccounts["local_account_1"] - - pinned, err := suite.db.CountAccountPinned(context.Background(), testAccount.ID) - suite.NoError(err) - suite.Equal(pinned, 0) // This account has nothing pinned. -} - func (suite *AccountTestSuite) TestPopulateAccountWithUnknownMovedToURI() { testAccount := >smodel.Account{} *testAccount = *suite.testAccounts["local_account_1"] @@ -676,6 +650,55 @@ func (suite *AccountTestSuite) TestGetPendingAccounts() { suite.Len(accounts, 1) } +func (suite *AccountTestSuite) TestAccountStatsAll() { + ctx := context.Background() + for _, account := range suite.testAccounts { + // Get stats for the first time. They + // should all be generated now since + // they're not stored in the test rig. + if err := suite.db.PopulateAccountStats(ctx, account); err != nil { + suite.FailNow(err.Error()) + } + stats := account.Stats + suite.NotNil(stats) + suite.WithinDuration(time.Now(), stats.RegeneratedAt, 5*time.Second) + + // Get stats a second time. They shouldn't + // be regenerated since we just did it. + if err := suite.db.PopulateAccountStats(ctx, account); err != nil { + suite.FailNow(err.Error()) + } + stats2 := account.Stats + suite.NotNil(stats2) + suite.Equal(stats2.RegeneratedAt, stats.RegeneratedAt) + + // Update the stats to indicate they're out of date. + stats2.RegeneratedAt = time.Now().Add(-72 * time.Hour) + if err := suite.db.UpdateAccountStats(ctx, stats2, "regenerated_at"); err != nil { + suite.FailNow(err.Error()) + } + + // Get stats for a third time, they + // should get regenerated now, but + // only for local accounts. + if err := suite.db.PopulateAccountStats(ctx, account); err != nil { + suite.FailNow(err.Error()) + } + stats3 := account.Stats + suite.NotNil(stats3) + if account.IsLocal() { + suite.True(stats3.RegeneratedAt.After(stats.RegeneratedAt)) + } else { + suite.False(stats3.RegeneratedAt.After(stats.RegeneratedAt)) + } + + // Now delete the stats. + if err := suite.db.DeleteAccountStats(ctx, account.ID); err != nil { + suite.FailNow(err.Error()) + } + } +} + func TestAccountTestSuite(t *testing.T) { suite.Run(t, new(AccountTestSuite)) } diff --git a/internal/db/bundb/migrations/20240414122348_account_stats_model.go b/internal/db/bundb/migrations/20240414122348_account_stats_model.go new file mode 100644 index 000000000..450ca04d4 --- /dev/null +++ b/internal/db/bundb/migrations/20240414122348_account_stats_model.go @@ -0,0 +1,52 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package migrations + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +func init() { + up := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + // Create new AccountStats table. + if _, err := tx. + NewCreateTable(). + Model(>smodel.AccountStats{}). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + + return nil + }) + } + + down := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + return nil + }) + } + + if err := Migrations.Register(up, down); err != nil { + panic(err) + } +} diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index 1c533af39..052f29cb3 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -112,7 +112,7 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount } func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) { - followIDs, err := r.getAccountFollowIDs(ctx, accountID, page) + followIDs, err := r.GetAccountFollowIDs(ctx, accountID, page) if err != nil { return nil, err } @@ -120,7 +120,7 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string } func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { - followIDs, err := r.getAccountLocalFollowIDs(ctx, accountID) + followIDs, err := r.GetAccountLocalFollowIDs(ctx, accountID) if err != nil { return nil, err } @@ -128,7 +128,7 @@ func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID s } func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) { - followerIDs, err := r.getAccountFollowerIDs(ctx, accountID, page) + followerIDs, err := r.GetAccountFollowerIDs(ctx, accountID, page) if err != nil { return nil, err } @@ -136,7 +136,7 @@ func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID stri } func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { - followerIDs, err := r.getAccountLocalFollowerIDs(ctx, accountID) + followerIDs, err := r.GetAccountLocalFollowerIDs(ctx, accountID) if err != nil { return nil, err } @@ -144,7 +144,7 @@ func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID } func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) { - followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID, page) + followReqIDs, err := r.GetAccountFollowRequestIDs(ctx, accountID, page) if err != nil { return nil, err } @@ -152,7 +152,7 @@ func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID } func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) { - followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID, page) + followReqIDs, err := r.GetAccountFollowRequestingIDs(ctx, accountID, page) if err != nil { return nil, err } @@ -160,49 +160,14 @@ func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, account } func (r *relationshipDB) GetAccountBlocks(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Block, error) { - blockIDs, err := r.getAccountBlockIDs(ctx, accountID, page) + blockIDs, err := r.GetAccountBlockIDs(ctx, accountID, page) if err != nil { return nil, err } return r.GetBlocksByIDs(ctx, blockIDs) } -func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) { - followIDs, err := r.getAccountFollowIDs(ctx, accountID, nil) - return len(followIDs), err -} - -func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) { - followIDs, err := r.getAccountLocalFollowIDs(ctx, accountID) - return len(followIDs), err -} - -func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) { - followerIDs, err := r.getAccountFollowerIDs(ctx, accountID, nil) - return len(followerIDs), err -} - -func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) { - followerIDs, err := r.getAccountLocalFollowerIDs(ctx, accountID) - return len(followerIDs), err -} - -func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) { - followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID, nil) - return len(followReqIDs), err -} - -func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) { - followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID, nil) - return len(followReqIDs), err -} - -func (r *relationshipDB) CountAccountBlocks(ctx context.Context, accountID string) (int, error) { - blockIDs, err := r.getAccountBlockIDs(ctx, accountID, nil) - return len(blockIDs), err -} - -func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { +func (r *relationshipDB) GetAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { return loadPagedIDs(&r.state.Caches.GTS.FollowIDs, ">"+accountID, page, func() ([]string, error) { var followIDs []string @@ -217,7 +182,7 @@ func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID stri }) } -func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error) { +func (r *relationshipDB) GetAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error) { return r.state.Caches.GTS.FollowIDs.Load("l>"+accountID, func() ([]string, error) { var followIDs []string @@ -232,7 +197,7 @@ func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID }) } -func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { +func (r *relationshipDB) GetAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { return loadPagedIDs(&r.state.Caches.GTS.FollowIDs, "<"+accountID, page, func() ([]string, error) { var followIDs []string @@ -247,7 +212,7 @@ func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID st }) } -func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error) { +func (r *relationshipDB) GetAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error) { return r.state.Caches.GTS.FollowIDs.Load("l<"+accountID, func() ([]string, error) { var followIDs []string @@ -262,7 +227,7 @@ func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, account }) } -func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { +func (r *relationshipDB) GetAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { return loadPagedIDs(&r.state.Caches.GTS.FollowRequestIDs, ">"+accountID, page, func() ([]string, error) { var followReqIDs []string @@ -277,7 +242,7 @@ func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, account }) } -func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { +func (r *relationshipDB) GetAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { return loadPagedIDs(&r.state.Caches.GTS.FollowRequestIDs, "<"+accountID, page, func() ([]string, error) { var followReqIDs []string @@ -292,7 +257,7 @@ func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, acco }) } -func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { +func (r *relationshipDB) GetAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { return loadPagedIDs(&r.state.Caches.GTS.BlockIDs, accountID, page, func() ([]string, error) { var blockIDs []string diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go index 9858e4768..f1d1a35d2 100644 --- a/internal/db/bundb/relationship_test.go +++ b/internal/db/bundb/relationship_test.go @@ -773,20 +773,6 @@ func (suite *RelationshipTestSuite) TestGetAccountFollows() { suite.Len(follows, 2) } -func (suite *RelationshipTestSuite) TestCountAccountFollowsLocalOnly() { - account := suite.testAccounts["local_account_1"] - followsCount, err := suite.db.CountAccountLocalFollows(context.Background(), account.ID) - suite.NoError(err) - suite.Equal(2, followsCount) -} - -func (suite *RelationshipTestSuite) TestCountAccountFollows() { - account := suite.testAccounts["local_account_1"] - followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID) - suite.NoError(err) - suite.Equal(2, followsCount) -} - func (suite *RelationshipTestSuite) TestGetAccountFollowers() { account := suite.testAccounts["local_account_1"] follows, err := suite.db.GetAccountFollowers(context.Background(), account.ID, nil) @@ -794,20 +780,6 @@ func (suite *RelationshipTestSuite) TestGetAccountFollowers() { suite.Len(follows, 2) } -func (suite *RelationshipTestSuite) TestCountAccountFollowers() { - account := suite.testAccounts["local_account_1"] - followsCount, err := suite.db.CountAccountFollowers(context.Background(), account.ID) - suite.NoError(err) - suite.Equal(2, followsCount) -} - -func (suite *RelationshipTestSuite) TestCountAccountFollowersLocalOnly() { - account := suite.testAccounts["local_account_1"] - followsCount, err := suite.db.CountAccountLocalFollowers(context.Background(), account.ID) - suite.NoError(err) - suite.Equal(2, followsCount) -} - func (suite *RelationshipTestSuite) TestUnfollowExisting() { originAccount := suite.testAccounts["local_account_1"] targetAccount := suite.testAccounts["admin_account"] diff --git a/internal/db/bundb/upsert.go b/internal/db/bundb/upsert.go index 34724446c..4a6395179 100644 --- a/internal/db/bundb/upsert.go +++ b/internal/db/bundb/upsert.go @@ -189,14 +189,14 @@ func (u *UpsertQuery) insertQuery() (*bun.InsertQuery, error) { constraintIDPlaceholders = append(constraintIDPlaceholders, "?") constraintIDs = append(constraintIDs, bun.Ident(constraint)) } - onSQL := "conflict (" + strings.Join(constraintIDPlaceholders, ", ") + ") do update" + onSQL := "CONFLICT (" + strings.Join(constraintIDPlaceholders, ", ") + ") DO UPDATE" setClauses := make([]string, 0, len(columns)) setIDs := make([]interface{}, 0, 2*len(columns)) for _, column := range columns { + setClauses = append(setClauses, "? = ?") // "excluded" is a special table that contains only the row involved in a conflict. - setClauses = append(setClauses, "? = excluded.?") - setIDs = append(setIDs, bun.Ident(column), bun.Ident(column)) + setIDs = append(setIDs, bun.Ident(column), bun.Ident("excluded."+column)) } setSQL := strings.Join(setClauses, ", ") diff --git a/internal/db/relationship.go b/internal/db/relationship.go index 5191701bb..cd4539791 100644 --- a/internal/db/relationship.go +++ b/internal/db/relationship.go @@ -140,44 +140,44 @@ type Relationship interface { // GetAccountFollows returns a slice of follows owned by the given accountID. GetAccountFollows(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) + // GetAccountFollowIDs is like GetAccountFollows, but returns just IDs. + GetAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) + // GetAccountLocalFollows returns a slice of follows owned by the given accountID, only including follows from this instance. GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) + // GetAccountLocalFollowIDs is like GetAccountLocalFollows, but returns just IDs. + GetAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error) + // GetAccountFollowers fetches follows that target given accountID. GetAccountFollowers(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) + // GetAccountFollowerIDs is like GetAccountFollowers, but returns just IDs. + GetAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) + // GetAccountLocalFollowers fetches follows that target given accountID, only including follows from this instance. GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) + // GetAccountLocalFollowerIDs is like GetAccountLocalFollowers, but returns just IDs. + GetAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error) + // GetAccountFollowRequests returns all follow requests targeting the given account. GetAccountFollowRequests(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) + // GetAccountFollowRequestIDs is like GetAccountFollowRequests, but returns just IDs. + GetAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) + // GetAccountFollowRequesting returns all follow requests originating from the given account. GetAccountFollowRequesting(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) + // GetAccountFollowRequestingIDs is like GetAccountFollowRequesting, but returns just IDs. + GetAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) + // GetAccountBlocks returns all blocks originating from the given account, with given optional paging parameters. GetAccountBlocks(ctx context.Context, accountID string, paging *paging.Page) ([]*gtsmodel.Block, error) - // CountAccountFollows returns the amount of accounts that the given accountID is following. - CountAccountFollows(ctx context.Context, accountID string) (int, error) - - // CountAccountLocalFollows returns the amount of accounts that the given accountID is following, only including follows from this instance. - CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) - - // CountAccountFollowers returns the amounts that the given ID is followed by. - CountAccountFollowers(ctx context.Context, accountID string) (int, error) - - // CountAccountLocalFollowers returns the amounts that the given ID is followed by, only including follows from this instance. - CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) - - // CountAccountFollowRequests returns number of follow requests targeting the given account. - CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) - - // CountAccountFollowerRequests returns number of follow requests originating from the given account. - CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) - - // CountAccountBlocks ... - CountAccountBlocks(ctx context.Context, accountID string) (int, error) + // GetAccountBlockIDs is like GetAccountBlocks, but returns just IDs. + GetAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) // GetNote gets a private note from a source account on a target account, if it exists. GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error) |