summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
authorLibravatar tobi <31960611+tsmethurst@users.noreply.github.com>2024-04-16 13:10:13 +0200
committerLibravatar GitHub <noreply@github.com>2024-04-16 13:10:13 +0200
commit3cceed11b28b5f42a653d85ed779d652fd8c26ad (patch)
tree0a7f0994e477609ca705a45f382dfb62056b196e /internal/db
parent[performance] cached oauth database types (#2838) (diff)
downloadgotosocial-3cceed11b28b5f42a653d85ed779d652fd8c26ad.tar.xz
[feature/performance] Store account stats in separate table (#2831)
* [feature/performance] Store account stats in separate table, get stats from remote * test account stats * add some missing increment / decrement calls * change stats function signatures * rejig logging a bit * use lock when updating stats
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/account.go34
-rw-r--r--internal/db/bundb/account.go267
-rw-r--r--internal/db/bundb/account_test.go79
-rw-r--r--internal/db/bundb/migrations/20240414122348_account_stats_model.go52
-rw-r--r--internal/db/bundb/relationship.go63
-rw-r--r--internal/db/bundb/relationship_test.go28
-rw-r--r--internal/db/bundb/upsert.go6
-rw-r--r--internal/db/relationship.go40
8 files changed, 349 insertions, 220 deletions
diff --git a/internal/db/account.go b/internal/db/account.go
index 7cdf7b57f..dec36d2ac 100644
--- a/internal/db/account.go
+++ b/internal/db/account.go
@@ -20,7 +20,6 @@ package db
import (
"context"
"net/netip"
- "time"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/paging"
@@ -100,12 +99,6 @@ type Account interface {
// GetAccountsUsingEmoji fetches all account models using emoji with given ID stored in their 'emojis' column.
GetAccountsUsingEmoji(ctx context.Context, emojiID string) ([]*gtsmodel.Account, error)
- // GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID.
- CountAccountStatuses(ctx context.Context, accountID string) (int, error)
-
- // CountAccountPinned returns the total number of pinned statuses owned by account with the given id.
- CountAccountPinned(ctx context.Context, accountID string) (int, error)
-
// GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided
// then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
// be very memory intensive so you probably shouldn't do this!
@@ -128,13 +121,6 @@ type Account interface {
// In the case of no statuses, this function will return db.ErrNoEntries.
GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error)
- // GetAccountLastPosted simply gets the timestamp of the most recent post by the account.
- //
- // If webOnly is true, then the time of the last non-reply, non-boost, public status of the account will be returned.
- //
- // The returned time will be zero if account has never posted anything.
- GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, error)
-
// SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment.
SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) error
@@ -150,4 +136,24 @@ type Account interface {
// Update local account settings.
UpdateAccountSettings(ctx context.Context, settings *gtsmodel.AccountSettings, columns ...string) error
+
+ // PopulateAccountStats gets (or creates and gets) account stats for
+ // the given account, and attaches them to the account model.
+ PopulateAccountStats(ctx context.Context, account *gtsmodel.Account) error
+
+ // RegenerateAccountStats creates, upserts, and returns stats
+ // for the given account, and attaches them to the account model.
+ //
+ // Unlike GetAccountStats, it will always get the database stats fresh.
+ // This can be used to "refresh" stats.
+ //
+ // Because this involves database calls that can be expensive (on Postgres
+ // specifically), callers should prefer GetAccountStats in 99% of cases.
+ RegenerateAccountStats(ctx context.Context, account *gtsmodel.Account) error
+
+ // Update account stats.
+ UpdateAccountStats(ctx context.Context, stats *gtsmodel.AccountStats, columns ...string) error
+
+ // DeleteAccountStats deletes the accountStats entry for the given accountID.
+ DeleteAccountStats(ctx context.Context, accountID string) error
}
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index 45e67c10b..2b3c78aff 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -630,6 +630,13 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou
}
}
+ if account.Stats == nil {
+ // Get / Create stats for this account.
+ if err := a.state.DB.PopulateAccountStats(ctx, account); err != nil {
+ errs.Appendf("error populating account stats: %w", err)
+ }
+ }
+
return errs.Combine()
}
@@ -735,31 +742,6 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) error {
})
}
-func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, error) {
- createdAt := time.Time{}
-
- q := a.db.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
- Column("status.created_at").
- Where("? = ?", bun.Ident("status.account_id"), accountID).
- Order("status.id DESC").
- Limit(1)
-
- if webOnly {
- q = q.
- Where("? IS NULL", bun.Ident("status.in_reply_to_uri")).
- Where("? IS NULL", bun.Ident("status.boost_of_id")).
- Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic).
- Where("? = ?", bun.Ident("status.federated"), true)
- }
-
- if err := q.Scan(ctx, &createdAt); err != nil {
- return time.Time{}, err
- }
- return createdAt, nil
-}
-
func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) error {
if *mediaAttachment.Avatar && *mediaAttachment.Header {
return errors.New("one media attachment cannot be both header and avatar")
@@ -845,59 +827,6 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g
return *faves, nil
}
-func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, error) {
- counts, err := a.getAccountStatusCounts(ctx, accountID)
- return counts.Statuses, err
-}
-
-func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (int, error) {
- counts, err := a.getAccountStatusCounts(ctx, accountID)
- return counts.Pinned, err
-}
-
-func (a *accountDB) getAccountStatusCounts(ctx context.Context, accountID string) (struct {
- Statuses int
- Pinned int
-}, error) {
- // Check for an already cached copy of account status counts.
- counts, ok := a.state.Caches.GTS.AccountCounts.Get(accountID)
- if ok {
- return counts, nil
- }
-
- if err := a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
- var err error
-
- // Scan database for account statuses.
- counts.Statuses, err = tx.NewSelect().
- Table("statuses").
- Where("? = ?", bun.Ident("account_id"), accountID).
- Count(ctx)
- if err != nil {
- return err
- }
-
- // Scan database for pinned statuses.
- counts.Pinned, err = tx.NewSelect().
- Table("statuses").
- Where("? = ?", bun.Ident("account_id"), accountID).
- Where("? IS NOT NULL", bun.Ident("pinned_at")).
- Count(ctx)
- if err != nil {
- return err
- }
-
- return nil
- }); err != nil {
- return counts, err
- }
-
- // Store this account counts result in the cache.
- a.state.Caches.GTS.AccountCounts.Set(accountID, counts)
-
- return counts, nil
-}
-
func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, error) {
// Ensure reasonable
if limit < 0 {
@@ -1147,3 +1076,185 @@ func (a *accountDB) UpdateAccountSettings(
return nil
})
}
+
+func (a *accountDB) PopulateAccountStats(ctx context.Context, account *gtsmodel.Account) error {
+ // Fetch stats from db cache with loader callback.
+ stats, err := a.state.Caches.GTS.AccountStats.LoadOne(
+ "AccountID",
+ func() (*gtsmodel.AccountStats, error) {
+ // Not cached! Perform database query.
+ var stats gtsmodel.AccountStats
+ if err := a.db.
+ NewSelect().
+ Model(&stats).
+ Where("? = ?", bun.Ident("account_stats.account_id"), account.ID).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+ return &stats, nil
+ },
+ account.ID,
+ )
+
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ // Real error.
+ return err
+ }
+
+ if stats == nil {
+ // Don't have stats yet, generate them.
+ return a.RegenerateAccountStats(ctx, account)
+ }
+
+ // We have a stats, attach
+ // it to the account.
+ account.Stats = stats
+
+ // Check if this is a local
+ // stats by looking at the
+ // account they pertain to.
+ if account.IsRemote() {
+ // Account is remote. Updating
+ // stats for remote accounts is
+ // handled in the dereferencer.
+ //
+ // Nothing more to do!
+ return nil
+ }
+
+ // Stats account is local, check
+ // if we need to regenerate.
+ const statsFreshness = 48 * time.Hour
+ expiry := stats.RegeneratedAt.Add(statsFreshness)
+ if time.Now().After(expiry) {
+ // Stats have expired, regenerate them.
+ return a.RegenerateAccountStats(ctx, account)
+ }
+
+ // Stats are still fresh.
+ return nil
+}
+
+func (a *accountDB) RegenerateAccountStats(ctx context.Context, account *gtsmodel.Account) error {
+ // Initialize a new stats struct.
+ stats := &gtsmodel.AccountStats{
+ AccountID: account.ID,
+ RegeneratedAt: time.Now(),
+ }
+
+ // Count followers outside of transaction since
+ // it uses a cache + requires its own db calls.
+ followerIDs, err := a.state.DB.GetAccountFollowerIDs(ctx, account.ID, nil)
+ if err != nil {
+ return err
+ }
+ stats.FollowersCount = util.Ptr(len(followerIDs))
+
+ // Count following outside of transaction since
+ // it uses a cache + requires its own db calls.
+ followIDs, err := a.state.DB.GetAccountFollowIDs(ctx, account.ID, nil)
+ if err != nil {
+ return err
+ }
+ stats.FollowingCount = util.Ptr(len(followIDs))
+
+ // Count follow requests outside of transaction since
+ // it uses a cache + requires its own db calls.
+ followRequestIDs, err := a.state.DB.GetAccountFollowRequestIDs(ctx, account.ID, nil)
+ if err != nil {
+ return err
+ }
+ stats.FollowRequestsCount = util.Ptr(len(followRequestIDs))
+
+ // Populate remaining stats struct fields.
+ // This can be done inside a transaction.
+ if err := a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ var err error
+
+ // Scan database for account statuses.
+ statusesCount, err := tx.NewSelect().
+ Table("statuses").
+ Where("? = ?", bun.Ident("account_id"), account.ID).
+ Count(ctx)
+ if err != nil {
+ return err
+ }
+ stats.StatusesCount = &statusesCount
+
+ // Scan database for pinned statuses.
+ statusesPinnedCount, err := tx.NewSelect().
+ Table("statuses").
+ Where("? = ?", bun.Ident("account_id"), account.ID).
+ Where("? IS NOT NULL", bun.Ident("pinned_at")).
+ Count(ctx)
+ if err != nil {
+ return err
+ }
+ stats.StatusesPinnedCount = &statusesPinnedCount
+
+ // Scan database for last status.
+ lastStatusAt := time.Time{}
+ err = tx.
+ NewSelect().
+ TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
+ Column("status.created_at").
+ Where("? = ?", bun.Ident("status.account_id"), account.ID).
+ Order("status.id DESC").
+ Limit(1).
+ Scan(ctx, &lastStatusAt)
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return err
+ }
+ stats.LastStatusAt = lastStatusAt
+
+ return nil
+ }); err != nil {
+ return err
+ }
+
+ // Upsert this stats in case a race
+ // meant someone else inserted it first.
+ if err := a.state.Caches.GTS.AccountStats.Store(stats, func() error {
+ if _, err := NewUpsert(a.db).
+ Model(stats).
+ Constraint("account_id").
+ Exec(ctx); err != nil {
+ return err
+ }
+ return nil
+ }); err != nil {
+ return err
+ }
+
+ account.Stats = stats
+ return nil
+}
+
+func (a *accountDB) UpdateAccountStats(ctx context.Context, stats *gtsmodel.AccountStats, columns ...string) error {
+ return a.state.Caches.GTS.AccountStats.Store(stats, func() error {
+ if _, err := a.db.
+ NewUpdate().
+ Model(stats).
+ Column(columns...).
+ Where("? = ?", bun.Ident("account_stats.account_id"), stats.AccountID).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ return nil
+ })
+}
+
+func (a *accountDB) DeleteAccountStats(ctx context.Context, accountID string) error {
+ defer a.state.Caches.GTS.AccountStats.Invalidate("AccountID", accountID)
+
+ if _, err := a.db.
+ NewDelete().
+ Table("account_stats").
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go
index dd96543b6..ea211e16f 100644
--- a/internal/db/bundb/account_test.go
+++ b/internal/db/bundb/account_test.go
@@ -220,6 +220,8 @@ func (suite *AccountTestSuite) TestGetAccountBy() {
a2.Emojis = nil
a1.Settings = nil
a2.Settings = nil
+ a1.Stats = nil
+ a2.Stats = nil
// Clear database-set fields.
a1.CreatedAt = time.Time{}
@@ -413,18 +415,6 @@ func (suite *AccountTestSuite) TestUpdateAccount() {
suite.WithinDuration(time.Now(), noCache.UpdatedAt, 5*time.Second)
}
-func (suite *AccountTestSuite) TestGetAccountLastPosted() {
- lastPosted, err := suite.db.GetAccountLastPosted(context.Background(), suite.testAccounts["local_account_1"].ID, false)
- suite.NoError(err)
- suite.EqualValues(1702200240, lastPosted.Unix())
-}
-
-func (suite *AccountTestSuite) TestGetAccountLastPostedWebOnly() {
- lastPosted, err := suite.db.GetAccountLastPosted(context.Background(), suite.testAccounts["local_account_1"].ID, true)
- suite.NoError(err)
- suite.EqualValues(1702200240, lastPosted.Unix())
-}
-
func (suite *AccountTestSuite) TestInsertAccountWithDefaults() {
key, err := rsa.GenerateKey(rand.Reader, 2048)
suite.NoError(err)
@@ -466,22 +456,6 @@ func (suite *AccountTestSuite) TestGetAccountPinnedStatusesNothingPinned() {
suite.Empty(statuses) // This account has nothing pinned.
}
-func (suite *AccountTestSuite) TestCountAccountPinnedSomeResults() {
- testAccount := suite.testAccounts["admin_account"]
-
- pinned, err := suite.db.CountAccountPinned(context.Background(), testAccount.ID)
- suite.NoError(err)
- suite.Equal(pinned, 2) // This account has 2 statuses pinned.
-}
-
-func (suite *AccountTestSuite) TestCountAccountPinnedNothingPinned() {
- testAccount := suite.testAccounts["local_account_1"]
-
- pinned, err := suite.db.CountAccountPinned(context.Background(), testAccount.ID)
- suite.NoError(err)
- suite.Equal(pinned, 0) // This account has nothing pinned.
-}
-
func (suite *AccountTestSuite) TestPopulateAccountWithUnknownMovedToURI() {
testAccount := &gtsmodel.Account{}
*testAccount = *suite.testAccounts["local_account_1"]
@@ -676,6 +650,55 @@ func (suite *AccountTestSuite) TestGetPendingAccounts() {
suite.Len(accounts, 1)
}
+func (suite *AccountTestSuite) TestAccountStatsAll() {
+ ctx := context.Background()
+ for _, account := range suite.testAccounts {
+ // Get stats for the first time. They
+ // should all be generated now since
+ // they're not stored in the test rig.
+ if err := suite.db.PopulateAccountStats(ctx, account); err != nil {
+ suite.FailNow(err.Error())
+ }
+ stats := account.Stats
+ suite.NotNil(stats)
+ suite.WithinDuration(time.Now(), stats.RegeneratedAt, 5*time.Second)
+
+ // Get stats a second time. They shouldn't
+ // be regenerated since we just did it.
+ if err := suite.db.PopulateAccountStats(ctx, account); err != nil {
+ suite.FailNow(err.Error())
+ }
+ stats2 := account.Stats
+ suite.NotNil(stats2)
+ suite.Equal(stats2.RegeneratedAt, stats.RegeneratedAt)
+
+ // Update the stats to indicate they're out of date.
+ stats2.RegeneratedAt = time.Now().Add(-72 * time.Hour)
+ if err := suite.db.UpdateAccountStats(ctx, stats2, "regenerated_at"); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // Get stats for a third time, they
+ // should get regenerated now, but
+ // only for local accounts.
+ if err := suite.db.PopulateAccountStats(ctx, account); err != nil {
+ suite.FailNow(err.Error())
+ }
+ stats3 := account.Stats
+ suite.NotNil(stats3)
+ if account.IsLocal() {
+ suite.True(stats3.RegeneratedAt.After(stats.RegeneratedAt))
+ } else {
+ suite.False(stats3.RegeneratedAt.After(stats.RegeneratedAt))
+ }
+
+ // Now delete the stats.
+ if err := suite.db.DeleteAccountStats(ctx, account.ID); err != nil {
+ suite.FailNow(err.Error())
+ }
+ }
+}
+
func TestAccountTestSuite(t *testing.T) {
suite.Run(t, new(AccountTestSuite))
}
diff --git a/internal/db/bundb/migrations/20240414122348_account_stats_model.go b/internal/db/bundb/migrations/20240414122348_account_stats_model.go
new file mode 100644
index 000000000..450ca04d4
--- /dev/null
+++ b/internal/db/bundb/migrations/20240414122348_account_stats_model.go
@@ -0,0 +1,52 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package migrations
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+func init() {
+ up := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ // Create new AccountStats table.
+ if _, err := tx.
+ NewCreateTable().
+ Model(&gtsmodel.AccountStats{}).
+ IfNotExists().
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ return nil
+ })
+ }
+
+ down := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ return nil
+ })
+ }
+
+ if err := Migrations.Register(up, down); err != nil {
+ panic(err)
+ }
+}
diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go
index 1c533af39..052f29cb3 100644
--- a/internal/db/bundb/relationship.go
+++ b/internal/db/bundb/relationship.go
@@ -112,7 +112,7 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
}
func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) {
- followIDs, err := r.getAccountFollowIDs(ctx, accountID, page)
+ followIDs, err := r.GetAccountFollowIDs(ctx, accountID, page)
if err != nil {
return nil, err
}
@@ -120,7 +120,7 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string
}
func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
- followIDs, err := r.getAccountLocalFollowIDs(ctx, accountID)
+ followIDs, err := r.GetAccountLocalFollowIDs(ctx, accountID)
if err != nil {
return nil, err
}
@@ -128,7 +128,7 @@ func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID s
}
func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) {
- followerIDs, err := r.getAccountFollowerIDs(ctx, accountID, page)
+ followerIDs, err := r.GetAccountFollowerIDs(ctx, accountID, page)
if err != nil {
return nil, err
}
@@ -136,7 +136,7 @@ func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID stri
}
func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
- followerIDs, err := r.getAccountLocalFollowerIDs(ctx, accountID)
+ followerIDs, err := r.GetAccountLocalFollowerIDs(ctx, accountID)
if err != nil {
return nil, err
}
@@ -144,7 +144,7 @@ func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID
}
func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) {
- followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID, page)
+ followReqIDs, err := r.GetAccountFollowRequestIDs(ctx, accountID, page)
if err != nil {
return nil, err
}
@@ -152,7 +152,7 @@ func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID
}
func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) {
- followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID, page)
+ followReqIDs, err := r.GetAccountFollowRequestingIDs(ctx, accountID, page)
if err != nil {
return nil, err
}
@@ -160,49 +160,14 @@ func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, account
}
func (r *relationshipDB) GetAccountBlocks(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Block, error) {
- blockIDs, err := r.getAccountBlockIDs(ctx, accountID, page)
+ blockIDs, err := r.GetAccountBlockIDs(ctx, accountID, page)
if err != nil {
return nil, err
}
return r.GetBlocksByIDs(ctx, blockIDs)
}
-func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) {
- followIDs, err := r.getAccountFollowIDs(ctx, accountID, nil)
- return len(followIDs), err
-}
-
-func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) {
- followIDs, err := r.getAccountLocalFollowIDs(ctx, accountID)
- return len(followIDs), err
-}
-
-func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) {
- followerIDs, err := r.getAccountFollowerIDs(ctx, accountID, nil)
- return len(followerIDs), err
-}
-
-func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) {
- followerIDs, err := r.getAccountLocalFollowerIDs(ctx, accountID)
- return len(followerIDs), err
-}
-
-func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) {
- followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID, nil)
- return len(followReqIDs), err
-}
-
-func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) {
- followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID, nil)
- return len(followReqIDs), err
-}
-
-func (r *relationshipDB) CountAccountBlocks(ctx context.Context, accountID string) (int, error) {
- blockIDs, err := r.getAccountBlockIDs(ctx, accountID, nil)
- return len(blockIDs), err
-}
-
-func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
+func (r *relationshipDB) GetAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(&r.state.Caches.GTS.FollowIDs, ">"+accountID, page, func() ([]string, error) {
var followIDs []string
@@ -217,7 +182,7 @@ func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID stri
})
}
-func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error) {
+func (r *relationshipDB) GetAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error) {
return r.state.Caches.GTS.FollowIDs.Load("l>"+accountID, func() ([]string, error) {
var followIDs []string
@@ -232,7 +197,7 @@ func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID
})
}
-func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
+func (r *relationshipDB) GetAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(&r.state.Caches.GTS.FollowIDs, "<"+accountID, page, func() ([]string, error) {
var followIDs []string
@@ -247,7 +212,7 @@ func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID st
})
}
-func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error) {
+func (r *relationshipDB) GetAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error) {
return r.state.Caches.GTS.FollowIDs.Load("l<"+accountID, func() ([]string, error) {
var followIDs []string
@@ -262,7 +227,7 @@ func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, account
})
}
-func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
+func (r *relationshipDB) GetAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(&r.state.Caches.GTS.FollowRequestIDs, ">"+accountID, page, func() ([]string, error) {
var followReqIDs []string
@@ -277,7 +242,7 @@ func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, account
})
}
-func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
+func (r *relationshipDB) GetAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(&r.state.Caches.GTS.FollowRequestIDs, "<"+accountID, page, func() ([]string, error) {
var followReqIDs []string
@@ -292,7 +257,7 @@ func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, acco
})
}
-func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
+func (r *relationshipDB) GetAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(&r.state.Caches.GTS.BlockIDs, accountID, page, func() ([]string, error) {
var blockIDs []string
diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go
index 9858e4768..f1d1a35d2 100644
--- a/internal/db/bundb/relationship_test.go
+++ b/internal/db/bundb/relationship_test.go
@@ -773,20 +773,6 @@ func (suite *RelationshipTestSuite) TestGetAccountFollows() {
suite.Len(follows, 2)
}
-func (suite *RelationshipTestSuite) TestCountAccountFollowsLocalOnly() {
- account := suite.testAccounts["local_account_1"]
- followsCount, err := suite.db.CountAccountLocalFollows(context.Background(), account.ID)
- suite.NoError(err)
- suite.Equal(2, followsCount)
-}
-
-func (suite *RelationshipTestSuite) TestCountAccountFollows() {
- account := suite.testAccounts["local_account_1"]
- followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID)
- suite.NoError(err)
- suite.Equal(2, followsCount)
-}
-
func (suite *RelationshipTestSuite) TestGetAccountFollowers() {
account := suite.testAccounts["local_account_1"]
follows, err := suite.db.GetAccountFollowers(context.Background(), account.ID, nil)
@@ -794,20 +780,6 @@ func (suite *RelationshipTestSuite) TestGetAccountFollowers() {
suite.Len(follows, 2)
}
-func (suite *RelationshipTestSuite) TestCountAccountFollowers() {
- account := suite.testAccounts["local_account_1"]
- followsCount, err := suite.db.CountAccountFollowers(context.Background(), account.ID)
- suite.NoError(err)
- suite.Equal(2, followsCount)
-}
-
-func (suite *RelationshipTestSuite) TestCountAccountFollowersLocalOnly() {
- account := suite.testAccounts["local_account_1"]
- followsCount, err := suite.db.CountAccountLocalFollowers(context.Background(), account.ID)
- suite.NoError(err)
- suite.Equal(2, followsCount)
-}
-
func (suite *RelationshipTestSuite) TestUnfollowExisting() {
originAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["admin_account"]
diff --git a/internal/db/bundb/upsert.go b/internal/db/bundb/upsert.go
index 34724446c..4a6395179 100644
--- a/internal/db/bundb/upsert.go
+++ b/internal/db/bundb/upsert.go
@@ -189,14 +189,14 @@ func (u *UpsertQuery) insertQuery() (*bun.InsertQuery, error) {
constraintIDPlaceholders = append(constraintIDPlaceholders, "?")
constraintIDs = append(constraintIDs, bun.Ident(constraint))
}
- onSQL := "conflict (" + strings.Join(constraintIDPlaceholders, ", ") + ") do update"
+ onSQL := "CONFLICT (" + strings.Join(constraintIDPlaceholders, ", ") + ") DO UPDATE"
setClauses := make([]string, 0, len(columns))
setIDs := make([]interface{}, 0, 2*len(columns))
for _, column := range columns {
+ setClauses = append(setClauses, "? = ?")
// "excluded" is a special table that contains only the row involved in a conflict.
- setClauses = append(setClauses, "? = excluded.?")
- setIDs = append(setIDs, bun.Ident(column), bun.Ident(column))
+ setIDs = append(setIDs, bun.Ident(column), bun.Ident("excluded."+column))
}
setSQL := strings.Join(setClauses, ", ")
diff --git a/internal/db/relationship.go b/internal/db/relationship.go
index 5191701bb..cd4539791 100644
--- a/internal/db/relationship.go
+++ b/internal/db/relationship.go
@@ -140,44 +140,44 @@ type Relationship interface {
// GetAccountFollows returns a slice of follows owned by the given accountID.
GetAccountFollows(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error)
+ // GetAccountFollowIDs is like GetAccountFollows, but returns just IDs.
+ GetAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error)
+
// GetAccountLocalFollows returns a slice of follows owned by the given accountID, only including follows from this instance.
GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
+ // GetAccountLocalFollowIDs is like GetAccountLocalFollows, but returns just IDs.
+ GetAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error)
+
// GetAccountFollowers fetches follows that target given accountID.
GetAccountFollowers(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error)
+ // GetAccountFollowerIDs is like GetAccountFollowers, but returns just IDs.
+ GetAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error)
+
// GetAccountLocalFollowers fetches follows that target given accountID, only including follows from this instance.
GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
+ // GetAccountLocalFollowerIDs is like GetAccountLocalFollowers, but returns just IDs.
+ GetAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error)
+
// GetAccountFollowRequests returns all follow requests targeting the given account.
GetAccountFollowRequests(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error)
+ // GetAccountFollowRequestIDs is like GetAccountFollowRequests, but returns just IDs.
+ GetAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error)
+
// GetAccountFollowRequesting returns all follow requests originating from the given account.
GetAccountFollowRequesting(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error)
+ // GetAccountFollowRequestingIDs is like GetAccountFollowRequesting, but returns just IDs.
+ GetAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error)
+
// GetAccountBlocks returns all blocks originating from the given account, with given optional paging parameters.
GetAccountBlocks(ctx context.Context, accountID string, paging *paging.Page) ([]*gtsmodel.Block, error)
- // CountAccountFollows returns the amount of accounts that the given accountID is following.
- CountAccountFollows(ctx context.Context, accountID string) (int, error)
-
- // CountAccountLocalFollows returns the amount of accounts that the given accountID is following, only including follows from this instance.
- CountAccountLocalFollows(ctx context.Context, accountID string) (int, error)
-
- // CountAccountFollowers returns the amounts that the given ID is followed by.
- CountAccountFollowers(ctx context.Context, accountID string) (int, error)
-
- // CountAccountLocalFollowers returns the amounts that the given ID is followed by, only including follows from this instance.
- CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error)
-
- // CountAccountFollowRequests returns number of follow requests targeting the given account.
- CountAccountFollowRequests(ctx context.Context, accountID string) (int, error)
-
- // CountAccountFollowerRequests returns number of follow requests originating from the given account.
- CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error)
-
- // CountAccountBlocks ...
- CountAccountBlocks(ctx context.Context, accountID string) (int, error)
+ // GetAccountBlockIDs is like GetAccountBlocks, but returns just IDs.
+ GetAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error)
// GetNote gets a private note from a source account on a target account, if it exists.
GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error)