diff options
Diffstat (limited to 'internal/db')
25 files changed, 2261 insertions, 951 deletions
diff --git a/internal/db/account.go b/internal/db/account.go index 6ecfea018..4a08918b0 100644 --- a/internal/db/account.go +++ b/internal/db/account.go @@ -41,6 +41,21 @@ type Account interface { // GetAccountByPubkeyID returns one account with the given public key URI (ID), or an error if something goes wrong. GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, Error) + // GetAccountByInboxURI returns one account with the given inbox_uri, or an error if something goes wrong. + GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, Error) + + // GetAccountByOutboxURI returns one account with the given outbox_uri, or an error if something goes wrong. + GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, Error) + + // GetAccountByFollowingURI returns one account with the given following_uri, or an error if something goes wrong. + GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, Error) + + // GetAccountByFollowersURI returns one account with the given followers_uri, or an error if something goes wrong. + GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, Error) + + // PopulateAccount ensures that all sub-models of an account are populated (e.g. avatar, header etc). + PopulateAccount(ctx context.Context, account *gtsmodel.Account) error + // PutAccount puts one account in the database. PutAccount(ctx context.Context, account *gtsmodel.Account) Error diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index df73168e2..ccf7aaa46 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -20,11 +20,13 @@ package bundb import ( "context" "errors" + "fmt" "strings" "time" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" @@ -37,18 +39,15 @@ type accountDB struct { state *state.State } -func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { - return a.conn. - NewSelect(). - Model(account) -} - func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { return a.getAccount( ctx, "ID", func(account *gtsmodel.Account) error { - return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx) + return a.conn.NewSelect(). + Model(account). + Where("? = ?", bun.Ident("account.id"), id). + Scan(ctx) }, id, ) @@ -59,7 +58,10 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel. ctx, "URI", func(account *gtsmodel.Account) error { - return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx) + return a.conn.NewSelect(). + Model(account). + Where("? = ?", bun.Ident("account.uri"), uri). + Scan(ctx) }, uri, ) @@ -70,7 +72,10 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel. ctx, "URL", func(account *gtsmodel.Account) error { - return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx) + return a.conn.NewSelect(). + Model(account). + Where("? = ?", bun.Ident("account.url"), url). + Scan(ctx) }, url, ) @@ -81,7 +86,8 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str ctx, "Username.Domain", func(account *gtsmodel.Account) error { - q := a.newAccountQ(account) + q := a.conn.NewSelect(). + Model(account) if domain != "" { q = q. @@ -105,12 +111,71 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo ctx, "PublicKeyURI", func(account *gtsmodel.Account) error { - return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx) + return a.conn.NewSelect(). + Model(account). + Where("? = ?", bun.Ident("account.public_key_uri"), id). + Scan(ctx) }, id, ) } +func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { + return a.getAccount( + ctx, + "InboxURI", + func(account *gtsmodel.Account) error { + return a.conn.NewSelect(). + Model(account). + Where("? = ?", bun.Ident("account.inbox_uri"), uri). + Scan(ctx) + }, + uri, + ) +} + +func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { + return a.getAccount( + ctx, + "OutboxURI", + func(account *gtsmodel.Account) error { + return a.conn.NewSelect(). + Model(account). + Where("? = ?", bun.Ident("account.outbox_uri"), uri). + Scan(ctx) + }, + uri, + ) +} + +func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { + return a.getAccount( + ctx, + "FollowersURI", + func(account *gtsmodel.Account) error { + return a.conn.NewSelect(). + Model(account). + Where("? = ?", bun.Ident("account.followers_uri"), uri). + Scan(ctx) + }, + uri, + ) +} + +func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { + return a.getAccount( + ctx, + "FollowingURI", + func(account *gtsmodel.Account) error { + return a.conn.NewSelect(). + Model(account). + Where("? = ?", bun.Ident("account.following_uri"), uri). + Scan(ctx) + }, + uri, + ) +} + func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { var username string @@ -141,31 +206,56 @@ func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func( return nil, err } - if account.AvatarMediaAttachmentID != "" { - // Set the account's related avatar - account.AvatarMediaAttachment, err = a.state.DB.GetAttachmentByID(ctx, account.AvatarMediaAttachmentID) + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return account, nil + } + + // Further populate the account fields where applicable. + if err := a.PopulateAccount(ctx, account); err != nil { + return nil, err + } + + return account, nil +} + +func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Account) error { + var err error + + if account.AvatarMediaAttachment == nil && account.AvatarMediaAttachmentID != "" { + // Account avatar attachment is not set, fetch from database. + account.AvatarMediaAttachment, err = a.state.DB.GetAttachmentByID( + ctx, // these are already barebones + account.AvatarMediaAttachmentID, + ) if err != nil { - log.Errorf(ctx, "error getting account %s avatar: %v", account.ID, err) + return fmt.Errorf("error populating account avatar: %w", err) } } - if account.HeaderMediaAttachmentID != "" { - // Set the account's related header - account.HeaderMediaAttachment, err = a.state.DB.GetAttachmentByID(ctx, account.HeaderMediaAttachmentID) + if account.HeaderMediaAttachment == nil && account.HeaderMediaAttachmentID != "" { + // Account header attachment is not set, fetch from database. + account.HeaderMediaAttachment, err = a.state.DB.GetAttachmentByID( + ctx, // these are already barebones + account.HeaderMediaAttachmentID, + ) if err != nil { - log.Errorf(ctx, "error getting account %s header: %v", account.ID, err) + return fmt.Errorf("error populating account header: %w", err) } } - if len(account.EmojiIDs) > 0 { - // Set the account's related emojis - account.Emojis, err = a.state.DB.GetEmojisByIDs(ctx, account.EmojiIDs) + if !account.EmojisPopulated() { + // Account emojis are out-of-date with IDs, repopulate. + account.Emojis, err = a.state.DB.GetEmojisByIDs( + ctx, // these are already barebones + account.EmojiIDs, + ) if err != nil { - log.Errorf(ctx, "error getting account %s emojis: %v", account.ID, err) + return fmt.Errorf("error populating account emojis: %w", err) } } - return account, nil + return nil } func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) db.Error { @@ -198,7 +288,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account columns = append(columns, "updated_at") } - return a.state.Caches.GTS.Account().Store(account, func() error { + err := a.state.Caches.GTS.Account().Store(account, func() error { // It is safe to run this database transaction within cache.Store // as the cache does not attempt a mutex lock until AFTER hook. // @@ -234,6 +324,11 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account return err }) }) + if err != nil { + return err + } + + return nil } func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { @@ -258,7 +353,9 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { return err } + // Invalidate account from database lookups. a.state.Caches.GTS.Account().Invalidate("ID", id) + return nil } diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go index b7e8aaadc..2241ab783 100644 --- a/internal/db/bundb/account_test.go +++ b/internal/db/bundb/account_test.go @@ -21,6 +21,8 @@ import ( "context" "crypto/rand" "crypto/rsa" + "errors" + "reflect" "strings" "testing" "time" @@ -61,44 +63,149 @@ func (suite *AccountTestSuite) TestGetAccountStatusesMediaOnly() { suite.Len(statuses, 1) } -func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() { - account, err := suite.db.GetAccountByID(context.Background(), suite.testAccounts["local_account_1"].ID) - if err != nil { - suite.FailNow(err.Error()) +func (suite *AccountTestSuite) TestGetAccountBy() { + t := suite.T() + + // Create a new context for this test. + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + // Sentinel error to mark avoiding a test case. + sentinelErr := errors.New("sentinel") + + // isEqual checks if 2 account models are equal. + isEqual := func(a1, a2 gtsmodel.Account) bool { + // Clear populated sub-models. + a1.HeaderMediaAttachment = nil + a2.HeaderMediaAttachment = nil + a1.AvatarMediaAttachment = nil + a2.AvatarMediaAttachment = nil + a1.Emojis = nil + a2.Emojis = nil + + // Clear database-set fields. + a1.CreatedAt = time.Time{} + a2.CreatedAt = time.Time{} + a1.UpdatedAt = time.Time{} + a2.UpdatedAt = time.Time{} + + // Manually compare keys. + pk1 := a1.PublicKey + pv1 := a1.PrivateKey + pk2 := a2.PublicKey + pv2 := a2.PrivateKey + a1.PublicKey = nil + a1.PrivateKey = nil + a2.PublicKey = nil + a2.PrivateKey = nil + + return reflect.DeepEqual(a1, a2) && + ((pk1 == nil && pk2 == nil) || pk1.Equal(pk2)) && + ((pv1 == nil && pv2 == nil) || pv1.Equal(pv2)) } - suite.NotNil(account) - suite.NotNil(account.AvatarMediaAttachment) - suite.NotEmpty(account.AvatarMediaAttachment.URL) - suite.NotNil(account.HeaderMediaAttachment) - suite.NotEmpty(account.HeaderMediaAttachment.URL) -} - -func (suite *AccountTestSuite) TestGetAccountByUsernameDomain() { - testAccount1 := suite.testAccounts["local_account_1"] - account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount1.Username, testAccount1.Domain) - suite.NoError(err) - suite.NotNil(account1) - - testAccount2 := suite.testAccounts["remote_account_1"] - account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount2.Username, testAccount2.Domain) - suite.NoError(err) - suite.NotNil(account2) -} - -func (suite *AccountTestSuite) TestGetAccountByUsernameDomainMixedCase() { - testAccount := suite.testAccounts["remote_account_2"] - account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount.Username, testAccount.Domain) - suite.NoError(err) - suite.NotNil(account1) - - account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), strings.ToUpper(testAccount.Username), testAccount.Domain) - suite.NoError(err) - suite.NotNil(account2) - - account3, err := suite.db.GetAccountByUsernameDomain(context.Background(), strings.ToLower(testAccount.Username), testAccount.Domain) - suite.NoError(err) - suite.NotNil(account3) + for _, account := range suite.testAccounts { + for lookup, dbfunc := range map[string]func() (*gtsmodel.Account, error){ + "id": func() (*gtsmodel.Account, error) { + return suite.db.GetAccountByID(ctx, account.ID) + }, + + "uri": func() (*gtsmodel.Account, error) { + return suite.db.GetAccountByURI(ctx, account.URI) + }, + + "url": func() (*gtsmodel.Account, error) { + if account.URL == "" { + return nil, sentinelErr + } + return suite.db.GetAccountByURL(ctx, account.URL) + }, + + "username@domain": func() (*gtsmodel.Account, error) { + return suite.db.GetAccountByUsernameDomain(ctx, account.Username, account.Domain) + }, + + "username_upper@domain": func() (*gtsmodel.Account, error) { + return suite.db.GetAccountByUsernameDomain(ctx, strings.ToUpper(account.Username), account.Domain) + }, + + "username_lower@domain": func() (*gtsmodel.Account, error) { + return suite.db.GetAccountByUsernameDomain(ctx, strings.ToLower(account.Username), account.Domain) + }, + + "public_key_uri": func() (*gtsmodel.Account, error) { + if account.PublicKeyURI == "" { + return nil, sentinelErr + } + return suite.db.GetAccountByPubkeyID(ctx, account.PublicKeyURI) + }, + + "inbox_uri": func() (*gtsmodel.Account, error) { + if account.InboxURI == "" { + return nil, sentinelErr + } + return suite.db.GetAccountByInboxURI(ctx, account.InboxURI) + }, + + "outbox_uri": func() (*gtsmodel.Account, error) { + if account.OutboxURI == "" { + return nil, sentinelErr + } + return suite.db.GetAccountByOutboxURI(ctx, account.OutboxURI) + }, + + "following_uri": func() (*gtsmodel.Account, error) { + if account.FollowingURI == "" { + return nil, sentinelErr + } + return suite.db.GetAccountByFollowingURI(ctx, account.FollowingURI) + }, + + "followers_uri": func() (*gtsmodel.Account, error) { + if account.FollowersURI == "" { + return nil, sentinelErr + } + return suite.db.GetAccountByFollowersURI(ctx, account.FollowersURI) + }, + } { + + // Clear database caches. + suite.state.Caches.Init() + + t.Logf("checking database lookup %q", lookup) + + // Perform database function. + checkAcc, err := dbfunc() + if err != nil { + if err == sentinelErr { + continue + } + + t.Errorf("error encountered for database lookup %q: %v", lookup, err) + continue + } + + // Check received account data. + if !isEqual(*checkAcc, *account) { + t.Errorf("account does not contain expected data: %+v", checkAcc) + continue + } + + // Check that avatar attachment populated. + if account.AvatarMediaAttachmentID != "" && + (checkAcc.AvatarMediaAttachment == nil || checkAcc.AvatarMediaAttachment.ID != account.AvatarMediaAttachmentID) { + t.Errorf("account avatar media attachment not correctly populated for: %+v", account) + continue + } + + // Check that header attachment populated. + if account.HeaderMediaAttachmentID != "" && + (checkAcc.HeaderMediaAttachment == nil || checkAcc.HeaderMediaAttachment.ID != account.HeaderMediaAttachmentID) { + t.Errorf("account header media attachment not correctly populated for: %+v", account) + continue + } + } + } } func (suite *AccountTestSuite) TestUpdateAccount() { diff --git a/internal/db/bundb/basic_test.go b/internal/db/bundb/basic_test.go index f238f2273..a24deac9e 100644 --- a/internal/db/bundb/basic_test.go +++ b/internal/db/bundb/basic_test.go @@ -19,6 +19,8 @@ package bundb_test import ( "context" + "crypto/rand" + "crypto/rsa" "testing" "time" @@ -40,6 +42,12 @@ func (suite *BasicTestSuite) TestGetAccountByID() { } func (suite *BasicTestSuite) TestPutAccountWithBunDefaultFields() { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + suite.FailNow(err.Error()) + } + + // Create an account that only just matches constraints. testAccount := >smodel.Account{ ID: "01GADR1AH9VCKH8YYCM86XSZ00", Username: "test", @@ -49,6 +57,7 @@ func (suite *BasicTestSuite) TestPutAccountWithBunDefaultFields() { OutboxURI: "https://example.org/users/test/outbox", ActorType: "Person", PublicKeyURI: "https://example.org/test#main-key", + PublicKey: &key.PublicKey, } if err := suite.db.Put(context.Background(), testAccount); err != nil { @@ -99,7 +108,7 @@ func (suite *BasicTestSuite) TestPutAccountWithBunDefaultFields() { suite.Empty(a.FeaturedCollectionURI) suite.Equal(testAccount.ActorType, a.ActorType) suite.Nil(a.PrivateKey) - suite.Nil(a.PublicKey) + suite.EqualValues(key.PublicKey, *a.PublicKey) suite.Equal(testAccount.PublicKeyURI, a.PublicKeyURI) suite.Zero(a.SensitizedAt) suite.Zero(a.SilencedAt) diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go index 1a9d3be05..d17d64b35 100644 --- a/internal/db/bundb/media.go +++ b/internal/db/bundb/media.go @@ -47,6 +47,24 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M ) } +func (m *mediaDB) GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error) { + attachments := make([]*gtsmodel.MediaAttachment, 0, len(ids)) + + for _, id := range ids { + // Attempt fetch from DB + attachment, err := m.GetAttachmentByID(ctx, id) + if err != nil { + log.Errorf(ctx, "error getting attachment %q: %v", id, err) + continue + } + + // Append attachment + attachments = append(attachments, attachment) + } + + return attachments, nil +} + func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func(*gtsmodel.MediaAttachment) error, keyParts ...any) (*gtsmodel.MediaAttachment, db.Error) { return m.state.Caches.GTS.Media().Load(lookup, func() (*gtsmodel.MediaAttachment, error) { var attachment gtsmodel.MediaAttachment @@ -118,7 +136,7 @@ func (m *mediaDB) GetRemoteOlderThan(ctx context.Context, olderThan time.Time, l return nil, m.conn.ProcessError(err) } - return m.getAttachments(ctx, attachmentIDs) + return m.GetAttachmentsByIDs(ctx, attachmentIDs) } func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time) (int, db.Error) { @@ -163,7 +181,7 @@ func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit return nil, m.conn.ProcessError(err) } - return m.getAttachments(ctx, attachmentIDs) + return m.GetAttachmentsByIDs(ctx, attachmentIDs) } func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, db.Error) { @@ -189,7 +207,7 @@ func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan tim return nil, m.conn.ProcessError(err) } - return m.getAttachments(ctx, attachmentIDs) + return m.GetAttachmentsByIDs(ctx, attachmentIDs) } func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time) (int, db.Error) { @@ -211,21 +229,3 @@ func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan t return count, nil } - -func (m *mediaDB) getAttachments(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, db.Error) { - attachments := make([]*gtsmodel.MediaAttachment, 0, len(ids)) - - for _, id := range ids { - // Attempt fetch from DB - attachment, err := m.GetAttachmentByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error getting attachment %q: %v", id, err) - continue - } - - // Append attachment - attachments = append(attachments, attachment) - } - - return attachments, nil -} diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go index 3a543f3c2..e64d6dac4 100644 --- a/internal/db/bundb/mention.go +++ b/internal/db/bundb/mention.go @@ -19,8 +19,10 @@ package bundb import ( "context" + "fmt" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" @@ -32,20 +34,13 @@ type mentionDB struct { state *state.State } -func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery { - return m.conn. - NewSelect(). - Model(i). - Relation("Status"). - Relation("OriginAccount"). - Relation("TargetAccount") -} - func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) { - return m.state.Caches.GTS.Mention().Load("ID", func() (*gtsmodel.Mention, error) { + mention, err := m.state.Caches.GTS.Mention().Load("ID", func() (*gtsmodel.Mention, error) { var mention gtsmodel.Mention - q := m.newMentionQ(&mention). + q := m.conn. + NewSelect(). + Model(&mention). Where("? = ?", bun.Ident("mention.id"), id) if err := q.Scan(ctx); err != nil { @@ -54,6 +49,38 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio return &mention, nil }, id) + if err != nil { + return nil, err + } + + // Set the mention originating status. + mention.Status, err = m.state.DB.GetStatusByID( + gtscontext.SetBarebones(ctx), + mention.StatusID, + ) + if err != nil { + return nil, fmt.Errorf("error populating mention status: %w", err) + } + + // Set the mention origin account model. + mention.OriginAccount, err = m.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + mention.OriginAccountID, + ) + if err != nil { + return nil, fmt.Errorf("error populating mention origin account: %w", err) + } + + // Set the mention target account model. + mention.TargetAccount, err = m.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + mention.TargetAccountID, + ) + if err != nil { + return nil, fmt.Errorf("error populating mention target account: %w", err) + } + + return mention, nil } func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) { @@ -73,3 +100,25 @@ func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel. return mentions, nil } + +func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error { + return m.state.Caches.GTS.Mention().Store(mention, func() error { + _, err := m.conn.NewInsert().Model(mention).Exec(ctx) + return m.conn.ProcessError(err) + }) +} + +func (m *mentionDB) DeleteMentionByID(ctx context.Context, id string) error { + if _, err := m.conn. + NewDelete(). + Table("mentions"). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx); err != nil { + return m.conn.ProcessError(err) + } + + // Invalidate mention from the lookup cache. + m.state.Caches.GTS.Mention().Invalidate("ID", id) + + return nil +} diff --git a/internal/db/bundb/migrations/20230328105630_chore_refactoring.go b/internal/db/bundb/migrations/20230328105630_chore_refactoring.go new file mode 100644 index 000000000..3bf9d59ef --- /dev/null +++ b/internal/db/bundb/migrations/20230328105630_chore_refactoring.go @@ -0,0 +1,167 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package migrations + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +func init() { + up := func(ctx context.Context, db *bun.DB) error { + // To update unique constraint on public key, we need to migrate accounts into a new table. + // See section 7 here: https://www.sqlite.org/lang_altertable.html + + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + // Create the new accounts table. + if _, err := tx. + NewCreateTable(). + ModelTableExpr("new_accounts"). + Model(>smodel.Account{}). + Exec(ctx); err != nil { + return err + } + + // If we don't specify columns explicitly, + // Postgres gives the following error when + // transferring accounts to new_accounts: + // + // ERROR: column "fetched_at" is of type timestamp with time zone but expression is of type character varying at character 35 + // HINT: You will need to rewrite or cast the expression. + // + // Rather than do funky casting to fix this, + // it's simpler to just specify all columns. + columns := []string{ + "id", + "created_at", + "updated_at", + "fetched_at", + "username", + "domain", + "avatar_media_attachment_id", + "avatar_remote_url", + "header_media_attachment_id", + "header_remote_url", + "display_name", + "emojis", + "fields", + "note", + "note_raw", + "memorial", + "also_known_as", + "moved_to_account_id", + "bot", + "reason", + "locked", + "discoverable", + "privacy", + "sensitive", + "language", + "status_content_type", + "custom_css", + "uri", + "url", + "inbox_uri", + "shared_inbox_uri", + "outbox_uri", + "following_uri", + "followers_uri", + "featured_collection_uri", + "actor_type", + "private_key", + "public_key", + "public_key_uri", + "sensitized_at", + "silenced_at", + "suspended_at", + "hide_collections", + "suspension_origin", + "enable_rss", + } + + // Copy all accounts to the new table. + if _, err := tx. + NewInsert(). + Table("new_accounts"). + Table("accounts"). + Column(columns...). + Exec(ctx); err != nil { + return err + } + + // Drop the old table. + if _, err := tx. + NewDropTable(). + Table("accounts"). + Exec(ctx); err != nil { + return err + } + + // Rename new table to old table. + if _, err := tx. + ExecContext( + ctx, + "ALTER TABLE ? RENAME TO ?", + bun.Ident("new_accounts"), + bun.Ident("accounts"), + ); err != nil { + return err + } + + // Add all account indexes to the new table. + for index, columns := range map[string][]string{ + // Standard indices. + "accounts_id_idx": {"id"}, + "accounts_suspended_at_idx": {"suspended_at"}, + "accounts_domain_idx": {"domain"}, + "accounts_username_domain_idx": {"username", "domain"}, + // URI indices. + "accounts_uri_idx": {"uri"}, + "accounts_url_idx": {"url"}, + "accounts_inbox_uri_idx": {"inbox_uri"}, + "accounts_outbox_uri_idx": {"outbox_uri"}, + "accounts_followers_uri_idx": {"followers_uri"}, + "accounts_following_uri_idx": {"following_uri"}, + "accounts_public_key_uri_idx": {"public_key_uri"}, + } { + if _, err := tx. + NewCreateIndex(). + Table("accounts"). + Index(index). + Column(columns...). + Exec(ctx); err != nil { + return err + } + } + + return nil + }) + } + + down := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + return nil + }) + } + + if err := Migrations.Register(up, down); err != nil { + panic(err) + } +} diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go index b1e7f45ff..f32aed092 100644 --- a/internal/db/bundb/notification.go +++ b/internal/db/bundb/notification.go @@ -33,7 +33,7 @@ type notificationDB struct { state *state.State } -func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) { +func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) { return n.state.Caches.GTS.Notification().Load("ID", func() (*gtsmodel.Notification, error) { var notif gtsmodel.Notification @@ -48,7 +48,7 @@ func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmo }, id) } -func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) { +func (n *notificationDB) GetAccountNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) { // Ensure reasonable if limit < 0 { limit = 0 @@ -92,7 +92,7 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, // reason for this is that for each notif, we can instead get it from our cache if it's cached for _, id := range notifIDs { // Attempt fetch from DB - notif, err := n.GetNotification(ctx, id) + notif, err := n.GetNotificationByID(ctx, id) if err != nil { log.Errorf(ctx, "error getting notification %q: %v", id, err) continue @@ -105,7 +105,14 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, return notifs, nil } -func (n *notificationDB) DeleteNotification(ctx context.Context, id string) db.Error { +func (n *notificationDB) PutNotification(ctx context.Context, notif *gtsmodel.Notification) error { + return n.state.Caches.GTS.Notification().Store(notif, func() error { + _, err := n.conn.NewInsert().Model(notif).Exec(ctx) + return n.conn.ProcessError(err) + }) +} + +func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) db.Error { if _, err := n.conn. NewDelete(). TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). @@ -118,19 +125,23 @@ func (n *notificationDB) DeleteNotification(ctx context.Context, id string) db.E return nil } -func (n *notificationDB) DeleteNotifications(ctx context.Context, targetAccountID string, originAccountID string) db.Error { +func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) db.Error { if targetAccountID == "" && originAccountID == "" { return errors.New("DeleteNotifications: one of targetAccountID or originAccountID must be set") } // Capture notification IDs in a RETURNING statement. - ids := []string{} + var ids []string q := n.conn. NewDelete(). TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). Returning("?", bun.Ident("id")) + if len(types) > 0 { + q = q.Where("? IN (?)", bun.Ident("notification.notification_type"), bun.In(types)) + } + if targetAccountID != "" { q = q.Where("? = ?", bun.Ident("notification.target_account_id"), targetAccountID) } @@ -153,7 +164,7 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, targetAccountI func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statusID string) db.Error { // Capture notification IDs in a RETURNING statement. - ids := []string{} + var ids []string q := n.conn. NewDelete(). diff --git a/internal/db/bundb/notification_test.go b/internal/db/bundb/notification_test.go index 117fc329c..bdee911b3 100644 --- a/internal/db/bundb/notification_test.go +++ b/internal/db/bundb/notification_test.go @@ -85,11 +85,11 @@ type NotificationTestSuite struct { BunDBStandardTestSuite } -func (suite *NotificationTestSuite) TestGetNotificationsWithSpam() { +func (suite *NotificationTestSuite) TestGetAccountNotificationsWithSpam() { suite.spamNotifs() testAccount := suite.testAccounts["local_account_1"] before := time.Now() - notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest) + notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest) suite.NoError(err) timeTaken := time.Since(before) fmt.Printf("\n\n\n withSpam: got %d notifications in %s\n\n\n", len(notifications), timeTaken) @@ -100,10 +100,10 @@ func (suite *NotificationTestSuite) TestGetNotificationsWithSpam() { } } -func (suite *NotificationTestSuite) TestGetNotificationsWithoutSpam() { +func (suite *NotificationTestSuite) TestGetAccountNotificationsWithoutSpam() { testAccount := suite.testAccounts["local_account_1"] before := time.Now() - notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest) + notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest) suite.NoError(err) timeTaken := time.Since(before) fmt.Printf("\n\n\n withoutSpam: got %d notifications in %s\n\n\n", len(notifications), timeTaken) @@ -117,10 +117,10 @@ func (suite *NotificationTestSuite) TestGetNotificationsWithoutSpam() { func (suite *NotificationTestSuite) TestDeleteNotificationsWithSpam() { suite.spamNotifs() testAccount := suite.testAccounts["local_account_1"] - err := suite.db.DeleteNotifications(context.Background(), testAccount.ID, "") + err := suite.db.DeleteNotifications(context.Background(), nil, testAccount.ID, "") suite.NoError(err) - notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest) + notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest) suite.NoError(err) suite.NotNil(notifications) suite.Empty(notifications) @@ -129,10 +129,10 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsWithSpam() { func (suite *NotificationTestSuite) TestDeleteNotificationsWithTwoAccounts() { suite.spamNotifs() testAccount := suite.testAccounts["local_account_1"] - err := suite.db.DeleteNotifications(context.Background(), testAccount.ID, "") + err := suite.db.DeleteNotifications(context.Background(), nil, testAccount.ID, "") suite.NoError(err) - notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest) + notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest) suite.NoError(err) suite.NotNil(notifications) suite.Empty(notifications) @@ -146,7 +146,7 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsWithTwoAccounts() { func (suite *NotificationTestSuite) TestDeleteNotificationsOriginatingFromAccount() { testAccount := suite.testAccounts["local_account_2"] - if err := suite.db.DeleteNotifications(context.Background(), "", testAccount.ID); err != nil { + if err := suite.db.DeleteNotifications(context.Background(), nil, "", testAccount.ID); err != nil { suite.FailNow(err.Error()) } @@ -166,7 +166,7 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsOriginatingFromAndTar originAccount := suite.testAccounts["local_account_2"] targetAccount := suite.testAccounts["admin_account"] - if err := suite.db.DeleteNotifications(context.Background(), targetAccount.ID, originAccount.ID); err != nil { + if err := suite.db.DeleteNotifications(context.Background(), nil, targetAccount.ID, originAccount.ID); err != nil { suite.FailNow(err.Error()) } diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index 21a29b5dc..82559a213 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -23,8 +23,8 @@ import ( "fmt" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/uptrace/bun" ) @@ -34,603 +34,212 @@ type relationshipDB struct { state *state.State } -func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) { - // Look for a block in direction of account1->account2 - block1, err := r.getBlock(ctx, account1, account2) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return false, err - } - - if block1 != nil { - // account1 blocks account2 - return true, nil - } else if !eitherDirection { - // Don't check for mutli-directional - return false, nil - } - - // Look for a block in direction of account2->account1 - block2, err := r.getBlock(ctx, account2, account1) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return false, err - } - - return (block2 != nil), nil -} - -func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) { - // Fetch block from database - block, err := r.getBlock(ctx, account1, account2) - if err != nil { - return nil, err - } - - // Set the block originating account - block.Account, err = r.state.DB.GetAccountByID(ctx, block.AccountID) - if err != nil { - return nil, err - } - - // Set the block target account - block.TargetAccount, err = r.state.DB.GetAccountByID(ctx, block.TargetAccountID) - if err != nil { - return nil, err - } - - return block, nil -} - -func (r *relationshipDB) getBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) { - return r.state.Caches.GTS.Block().Load("AccountID.TargetAccountID", func() (*gtsmodel.Block, error) { - var block gtsmodel.Block - - q := r.conn.NewSelect().Model(&block). - Where("? = ?", bun.Ident("block.account_id"), account1). - Where("? = ?", bun.Ident("block.target_account_id"), account2) - if err := q.Scan(ctx); err != nil { - return nil, r.conn.ProcessError(err) - } - - return &block, nil - }, account1, account2) -} - -func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) db.Error { - return r.state.Caches.GTS.Block().Store(block, func() error { - _, err := r.conn.NewInsert().Model(block).Exec(ctx) - return r.conn.ProcessError(err) - }) -} - -func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) db.Error { - if _, err := r.conn. - NewDelete(). - TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). - Where("? = ?", bun.Ident("block.id"), id). - Exec(ctx); err != nil { - return r.conn.ProcessError(err) - } - - // Drop any old value from cache by this ID - r.state.Caches.GTS.Block().Invalidate("ID", id) - return nil -} - -func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) db.Error { - if _, err := r.conn. - NewDelete(). - TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). - Where("? = ?", bun.Ident("block.uri"), uri). - Exec(ctx); err != nil { - return r.conn.ProcessError(err) - } - - // Drop any old value from cache by this URI - r.state.Caches.GTS.Block().Invalidate("URI", uri) - return nil -} - -func (r *relationshipDB) DeleteBlocksByOriginAccountID(ctx context.Context, originAccountID string) db.Error { - blockIDs := []string{} - - q := r.conn. - NewSelect(). - TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). - Column("block.id"). - Where("? = ?", bun.Ident("block.account_id"), originAccountID) - - if err := q.Scan(ctx, &blockIDs); err != nil { - return r.conn.ProcessError(err) - } - - for _, blockID := range blockIDs { - if err := r.DeleteBlockByID(ctx, blockID); err != nil { - return err - } - } - - return nil -} - -func (r *relationshipDB) DeleteBlocksByTargetAccountID(ctx context.Context, targetAccountID string) db.Error { - blockIDs := []string{} - - q := r.conn. - NewSelect(). - TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). - Column("block.id"). - Where("? = ?", bun.Ident("block.target_account_id"), targetAccountID) - - if err := q.Scan(ctx, &blockIDs); err != nil { - return r.conn.ProcessError(err) - } - - for _, blockID := range blockIDs { - if err := r.DeleteBlockByID(ctx, blockID); err != nil { - return err - } - } - - return nil -} - func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) { - rel := >smodel.Relationship{ - ID: targetAccount, + var rel gtsmodel.Relationship + rel.ID = targetAccount + + // check if the requesting follows the target + follow, err := r.GetFollow( + gtscontext.SetBarebones(ctx), + requestingAccount, + targetAccount, + ) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return nil, fmt.Errorf("GetRelationship: error fetching follow: %w", err) } - // check if the requesting account follows the target account - follow := >smodel.Follow{} - if err := r.conn. - NewSelect(). - Model(follow). - Column("follow.show_reblogs", "follow.notify"). - Where("? = ?", bun.Ident("follow.account_id"), requestingAccount). - Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount). - Limit(1). - Scan(ctx); err != nil { - if err := r.conn.ProcessError(err); err != db.ErrNoEntries { - return nil, fmt.Errorf("GetRelationship: error fetching follow: %s", err) - } - // no follow exists so these are all false - rel.Following = false - rel.ShowingReblogs = false - rel.Notifying = false - } else { + if follow != nil { // follow exists so we can fill these fields out... rel.Following = true rel.ShowingReblogs = *follow.ShowReblogs rel.Notifying = *follow.Notify } - // check if the target account follows the requesting account - followedByQ := r.conn. - NewSelect(). - TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). - Column("follow.id"). - Where("? = ?", bun.Ident("follow.account_id"), targetAccount). - Where("? = ?", bun.Ident("follow.target_account_id"), requestingAccount) - followedBy, err := r.conn.Exists(ctx, followedByQ) + // check if the target follows the requesting + rel.FollowedBy, err = r.IsFollowing(ctx, + targetAccount, + requestingAccount, + ) if err != nil { - return nil, fmt.Errorf("GetRelationship: error checking followedBy: %s", err) + return nil, fmt.Errorf("GetRelationship: error checking followedBy: %w", err) } - rel.FollowedBy = followedBy - // check if there's a pending following request from requesting account to target account - requestedQ := r.conn. - NewSelect(). - TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). - Column("follow_request.id"). - Where("? = ?", bun.Ident("follow_request.account_id"), requestingAccount). - Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount) - requested, err := r.conn.Exists(ctx, requestedQ) + // check if requesting has follow requested target + rel.Requested, err = r.IsFollowRequested(ctx, + requestingAccount, + targetAccount, + ) if err != nil { - return nil, fmt.Errorf("GetRelationship: error checking requested: %s", err) + return nil, fmt.Errorf("GetRelationship: error checking requested: %w", err) } - rel.Requested = requested // check if the requesting account is blocking the target account - blockA2T, err := r.getBlock(ctx, requestingAccount, targetAccount) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return nil, fmt.Errorf("GetRelationship: error checking blocking: %s", err) - } - rel.Blocking = (blockA2T != nil) - - // check if the requesting account is blocked by the target account - blockT2A, err := r.getBlock(ctx, targetAccount, requestingAccount) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %s", err) - } - rel.BlockedBy = (blockT2A != nil) - - return rel, nil -} - -func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) { - if sourceAccount == nil || targetAccount == nil { - return false, nil - } - - q := r.conn. - NewSelect(). - TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). - Column("follow.id"). - Where("? = ?", bun.Ident("follow.account_id"), sourceAccount.ID). - Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount.ID) - - return r.conn.Exists(ctx, q) -} - -func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) { - if sourceAccount == nil || targetAccount == nil { - return false, nil - } - - q := r.conn. - NewSelect(). - TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). - Column("follow_request.id"). - Where("? = ?", bun.Ident("follow_request.account_id"), sourceAccount.ID). - Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount.ID) - - return r.conn.Exists(ctx, q) -} - -func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) { - if account1 == nil || account2 == nil { - return false, nil - } - - // make sure account 1 follows account 2 - f1, err := r.IsFollowing(ctx, account1, account2) + rel.Blocking, err = r.IsBlocked(ctx, requestingAccount, targetAccount) if err != nil { - return false, err + return nil, fmt.Errorf("GetRelationship: error checking blocking: %w", err) } - // make sure account 2 follows account 1 - f2, err := r.IsFollowing(ctx, account2, account1) + // check if the requesting account is blocked by the target account + rel.BlockedBy, err = r.IsBlocked(ctx, targetAccount, requestingAccount) if err != nil { - return false, err + return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %w", err) } - return f1 && f2, nil + return &rel, nil } -func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) { - // Get original follow request. - var followRequestID string - if err := r.conn. - NewSelect(). - TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). - Column("follow_request.id"). - Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID). - Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID). - Scan(ctx, &followRequestID); err != nil { +func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { + var followIDs []string + if err := newSelectFollows(r.conn, accountID). + Scan(ctx, &followIDs); err != nil { return nil, r.conn.ProcessError(err) } - - followRequest, err := r.getFollowRequest(ctx, followRequestID) - if err != nil { - return nil, r.conn.ProcessError(err) - } - - // Create a new follow to 'replace' - // the original follow request with. - follow := >smodel.Follow{ - ID: followRequest.ID, - AccountID: originAccountID, - Account: followRequest.Account, - TargetAccountID: targetAccountID, - TargetAccount: followRequest.TargetAccount, - URI: followRequest.URI, - } - - // If the follow already exists, just - // replace the URI with the new one. - if _, err := r.conn. - NewInsert(). - Model(follow). - On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI). - Exec(ctx); err != nil { - return nil, r.conn.ProcessError(err) - } - - // Delete original follow request. - if _, err := r.conn. - NewDelete(). - TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). - Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID). - Exec(ctx); err != nil { - return nil, r.conn.ProcessError(err) - } - - // Delete original follow request notification. - if err := r.deleteFollowRequestNotif(ctx, originAccountID, targetAccountID); err != nil { - return nil, err - } - - // return the new follow - return follow, nil + return r.GetFollowsByIDs(ctx, followIDs) } -func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) { - // Get original follow request. - var followRequestID string - if err := r.conn. - NewSelect(). - TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). - Column("follow_request.id"). - Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID). - Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID). - Scan(ctx, &followRequestID); err != nil { +func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { + var followIDs []string + if err := newSelectLocalFollows(r.conn, accountID). + Scan(ctx, &followIDs); err != nil { return nil, r.conn.ProcessError(err) } + return r.GetFollowsByIDs(ctx, followIDs) +} - followRequest, err := r.getFollowRequest(ctx, followRequestID) - if err != nil { +func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { + var followIDs []string + if err := newSelectFollowers(r.conn, accountID). + Scan(ctx, &followIDs); err != nil { return nil, r.conn.ProcessError(err) } + return r.GetFollowsByIDs(ctx, followIDs) +} - // Delete original follow request. - if _, err := r.conn. - NewDelete(). - TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). - Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID). - Exec(ctx); err != nil { +func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { + var followIDs []string + if err := newSelectLocalFollowers(r.conn, accountID). + Scan(ctx, &followIDs); err != nil { return nil, r.conn.ProcessError(err) } - - // Delete original follow request notification. - if err := r.deleteFollowRequestNotif(ctx, originAccountID, targetAccountID); err != nil { - return nil, err - } - - // Return the now deleted follow request. - return followRequest, nil + return r.GetFollowsByIDs(ctx, followIDs) } -func (r *relationshipDB) deleteFollowRequestNotif(ctx context.Context, originAccountID string, targetAccountID string) db.Error { - var id string - if err := r.conn. - NewSelect(). - TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). - Column("notification.id"). - Where("? = ?", bun.Ident("notification.origin_account_id"), originAccountID). - Where("? = ?", bun.Ident("notification.target_account_id"), targetAccountID). - Where("? = ?", bun.Ident("notification.notification_type"), gtsmodel.NotificationFollowRequest). - Limit(1). // There should only be one! - Scan(ctx, &id); err != nil { - err = r.conn.ProcessError(err) - if errors.Is(err, db.ErrNoEntries) { - // If no entries, the notif didn't - // exist anyway so nothing to do here. - return nil - } - // Return on real error. - return err - } - - return r.state.DB.DeleteNotification(ctx, id) +func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) { + n, err := newSelectFollows(r.conn, accountID).Count(ctx) + return n, r.conn.ProcessError(err) } -func (r *relationshipDB) getFollow(ctx context.Context, id string) (*gtsmodel.Follow, db.Error) { - follow := >smodel.Follow{} - - err := r.conn. - NewSelect(). - Model(follow). - Where("? = ?", bun.Ident("follow.id"), id). - Scan(ctx) - if err != nil { - return nil, r.conn.ProcessError(err) - } - - follow.Account, err = r.state.DB.GetAccountByID(ctx, follow.AccountID) - if err != nil { - log.Errorf(ctx, "error getting follow account %q: %v", follow.AccountID, err) - } - - follow.TargetAccount, err = r.state.DB.GetAccountByID(ctx, follow.TargetAccountID) - if err != nil { - log.Errorf(ctx, "error getting follow target account %q: %v", follow.TargetAccountID, err) - } - - return follow, nil +func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) { + n, err := newSelectLocalFollows(r.conn, accountID).Count(ctx) + return n, r.conn.ProcessError(err) } -func (r *relationshipDB) GetLocalFollowersIDs(ctx context.Context, targetAccountID string) ([]string, db.Error) { - accountIDs := []string{} - - // Select only the account ID of each follow. - q := r.conn. - NewSelect(). - TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). - ColumnExpr("? AS ?", bun.Ident("follow.account_id"), bun.Ident("account_id")). - Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID) - - // Join on accounts table to select only - // those with NULL domain (local accounts). - q = q. - Join("JOIN ? AS ? ON ? = ?", - bun.Ident("accounts"), - bun.Ident("account"), - bun.Ident("follow.account_id"), - bun.Ident("account.id"), - ). - Where("? IS NULL", bun.Ident("account.domain")) +func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) { + n, err := newSelectFollowers(r.conn, accountID).Count(ctx) + return n, r.conn.ProcessError(err) +} - // We don't *really* need to order these, - // but it makes it more consistent to do so. - q = q.Order("account_id DESC") +func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) { + n, err := newSelectLocalFollowers(r.conn, accountID).Count(ctx) + return n, r.conn.ProcessError(err) +} - if err := q.Scan(ctx, &accountIDs); err != nil { +func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { + var followReqIDs []string + if err := newSelectFollowRequests(r.conn, accountID). + Scan(ctx, &followReqIDs); err != nil { return nil, r.conn.ProcessError(err) } - - return accountIDs, nil + return r.GetFollowRequestsByIDs(ctx, followReqIDs) } -func (r *relationshipDB) GetFollows(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.Follow, db.Error) { - ids := []string{} - - q := r.conn. - NewSelect(). - TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). - Column("follow.id"). - Order("follow.updated_at DESC") - - if accountID != "" { - q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID) - } - - if targetAccountID != "" { - q = q.Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID) - } - - if err := q.Scan(ctx, &ids); err != nil { +func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { + var followReqIDs []string + if err := newSelectFollowRequesting(r.conn, accountID). + Scan(ctx, &followReqIDs); err != nil { return nil, r.conn.ProcessError(err) } - - follows := make([]*gtsmodel.Follow, 0, len(ids)) - for _, id := range ids { - follow, err := r.getFollow(ctx, id) - if err != nil { - log.Errorf(ctx, "error getting follow %q: %v", id, err) - continue - } - - follows = append(follows, follow) - } - - return follows, nil + return r.GetFollowRequestsByIDs(ctx, followReqIDs) } -func (r *relationshipDB) CountFollows(ctx context.Context, accountID string, targetAccountID string) (int, db.Error) { - q := r.conn. - NewSelect(). - TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). - Column("follow.id") - - if accountID != "" { - q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID) - } - - if targetAccountID != "" { - q = q.Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID) - } - - return q.Count(ctx) +func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) { + n, err := newSelectFollowRequests(r.conn, accountID).Count(ctx) + return n, r.conn.ProcessError(err) } -func (r *relationshipDB) getFollowRequest(ctx context.Context, id string) (*gtsmodel.FollowRequest, db.Error) { - followRequest := >smodel.FollowRequest{} - - err := r.conn. - NewSelect(). - Model(followRequest). - Where("? = ?", bun.Ident("follow_request.id"), id). - Scan(ctx) - if err != nil { - return nil, r.conn.ProcessError(err) - } - - followRequest.Account, err = r.state.DB.GetAccountByID(ctx, followRequest.AccountID) - if err != nil { - log.Errorf(ctx, "error getting follow request account %q: %v", followRequest.AccountID, err) - } - - followRequest.TargetAccount, err = r.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID) - if err != nil { - log.Errorf(ctx, "error getting follow request target account %q: %v", followRequest.TargetAccountID, err) - } - - return followRequest, nil +func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) { + n, err := newSelectFollowRequesting(r.conn, accountID).Count(ctx) + return n, r.conn.ProcessError(err) } -func (r *relationshipDB) GetFollowRequests(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.FollowRequest, db.Error) { - ids := []string{} - - q := r.conn. - NewSelect(). - TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). - Column("follow_request.id") - - if accountID != "" { - q = q.Where("? = ?", bun.Ident("follow_request.account_id"), accountID) - } - - if targetAccountID != "" { - q = q.Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID) - } - - if err := q.Scan(ctx, &ids); err != nil { - return nil, r.conn.ProcessError(err) - } - - followRequests := make([]*gtsmodel.FollowRequest, 0, len(ids)) - for _, id := range ids { - followRequest, err := r.getFollowRequest(ctx, id) - if err != nil { - log.Errorf(ctx, "error getting follow request %q: %v", id, err) - continue - } - - followRequests = append(followRequests, followRequest) - } - - return followRequests, nil +// newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID. +func newSelectFollowRequests(conn *DBConn, accountID string) *bun.SelectQuery { + return conn.NewSelect(). + TableExpr("?", bun.Ident("follow_requests")). + ColumnExpr("?", bun.Ident("id")). + Where("? = ?", bun.Ident("target_account_id"), accountID). + OrderExpr("? DESC", bun.Ident("updated_at")) } -func (r *relationshipDB) CountFollowRequests(ctx context.Context, accountID string, targetAccountID string) (int, db.Error) { - q := r.conn. - NewSelect(). - TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). - Column("follow_request.id"). - Order("follow_request.updated_at DESC") - - if accountID != "" { - q = q.Where("? = ?", bun.Ident("follow_request.account_id"), accountID) - } - - if targetAccountID != "" { - q = q.Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID) - } - - return q.Count(ctx) +// newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID. +func newSelectFollowRequesting(conn *DBConn, accountID string) *bun.SelectQuery { + return conn.NewSelect(). + TableExpr("?", bun.Ident("follow_requests")). + ColumnExpr("?", bun.Ident("id")). + Where("? = ?", bun.Ident("target_account_id"), accountID). + OrderExpr("? DESC", bun.Ident("updated_at")) } -func (r *relationshipDB) Unfollow(ctx context.Context, originAccountID string, targetAccountID string) (string, db.Error) { - uri := new(string) - - _, err := r.conn. - NewDelete(). - TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). - Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID). - Where("? = ?", bun.Ident("follow.account_id"), originAccountID). - Returning("?", bun.Ident("uri")).Exec(ctx, uri) - - // Only return proper errors. - if err = r.conn.ProcessError(err); err != db.ErrNoEntries { - return *uri, err - } - - return *uri, nil +// newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID. +func newSelectFollows(conn *DBConn, accountID string) *bun.SelectQuery { + return conn.NewSelect(). + Table("follows"). + Column("id"). + Where("? = ?", bun.Ident("account_id"), accountID). + OrderExpr("? DESC", bun.Ident("updated_at")) } -func (r *relationshipDB) UnfollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (string, db.Error) { - uri := new(string) - - _, err := r.conn. - NewDelete(). - TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). - Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID). - Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID). - Returning("?", bun.Ident("uri")).Exec(ctx, uri) - - // Only return proper errors. - if err = r.conn.ProcessError(err); err != db.ErrNoEntries { - return *uri, err - } - - return *uri, nil +// newSelectLocalFollows returns a new select query for all rows in the follows table with +// account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local). +func newSelectLocalFollows(conn *DBConn, accountID string) *bun.SelectQuery { + return conn.NewSelect(). + Table("follows"). + Column("id"). + Where("? = ? AND ? IN (?)", + bun.Ident("account_id"), + accountID, + bun.Ident("target_account_id"), + conn.NewSelect(). + Table("accounts"). + Column("id"). + Where("? IS NULL", bun.Ident("domain")), + ). + OrderExpr("? DESC", bun.Ident("updated_at")) +} + +// newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID. +func newSelectFollowers(conn *DBConn, accountID string) *bun.SelectQuery { + return conn.NewSelect(). + Table("follows"). + Column("id"). + Where("? = ?", bun.Ident("target_account_id"), accountID). + OrderExpr("? DESC", bun.Ident("updated_at")) +} + +// newSelectLocalFollowers returns a new select query for all rows in the follows table with +// target_account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local). +func newSelectLocalFollowers(conn *DBConn, accountID string) *bun.SelectQuery { + return conn.NewSelect(). + Table("follows"). + Column("id"). + Where("? = ? AND ? IN (?)", + bun.Ident("target_account_id"), + accountID, + bun.Ident("account_id"), + conn.NewSelect(). + Table("accounts"). + Column("id"). + Where("? IS NULL", bun.Ident("domain")), + ). + OrderExpr("? DESC", bun.Ident("updated_at")) } diff --git a/internal/db/bundb/relationship_block.go b/internal/db/bundb/relationship_block.go new file mode 100644 index 000000000..9232ea984 --- /dev/null +++ b/internal/db/bundb/relationship_block.go @@ -0,0 +1,218 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( + "context" + "errors" + "fmt" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/uptrace/bun" +) + +func (r *relationshipDB) IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) { + block, err := r.GetBlock( + gtscontext.SetBarebones(ctx), + sourceAccountID, + targetAccountID, + ) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return false, err + } + return (block != nil), nil +} + +func (r *relationshipDB) IsEitherBlocked(ctx context.Context, accountID1 string, accountID2 string) (bool, error) { + // Look for a block in direction of account1->account2 + b1, err := r.IsBlocked(ctx, accountID1, accountID2) + if err != nil || b1 { + return true, err + } + + // Look for a block in direction of account2->account1 + b2, err := r.IsBlocked(ctx, accountID2, accountID1) + if err != nil || b2 { + return true, err + } + + return false, nil +} + +func (r *relationshipDB) GetBlockByID(ctx context.Context, id string) (*gtsmodel.Block, error) { + return r.getBlock( + ctx, + "ID", + func(block *gtsmodel.Block) error { + return r.conn.NewSelect().Model(block). + Where("? = ?", bun.Ident("block.id"), id). + Scan(ctx) + }, + id, + ) +} + +func (r *relationshipDB) GetBlockByURI(ctx context.Context, uri string) (*gtsmodel.Block, error) { + return r.getBlock( + ctx, + "URI", + func(block *gtsmodel.Block) error { + return r.conn.NewSelect().Model(block). + Where("? = ?", bun.Ident("block.uri"), uri). + Scan(ctx) + }, + uri, + ) +} + +func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Block, error) { + return r.getBlock( + ctx, + "AccountID.TargetAccountID", + func(block *gtsmodel.Block) error { + return r.conn.NewSelect().Model(block). + Where("? = ?", bun.Ident("block.account_id"), sourceAccountID). + Where("? = ?", bun.Ident("block.target_account_id"), targetAccountID). + Scan(ctx) + }, + sourceAccountID, + targetAccountID, + ) +} + +func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Block) error, keyParts ...any) (*gtsmodel.Block, error) { + // Fetch block from cache with loader callback + block, err := r.state.Caches.GTS.Block().Load(lookup, func() (*gtsmodel.Block, error) { + var block gtsmodel.Block + + // Not cached! Perform database query + if err := dbQuery(&block); err != nil { + return nil, r.conn.ProcessError(err) + } + + return &block, nil + }, keyParts...) + if err != nil { + // already processe + return nil, err + } + + if gtscontext.Barebones(ctx) { + // Only a barebones model was requested. + return block, nil + } + + // Set the block source account + block.Account, err = r.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + block.AccountID, + ) + if err != nil { + return nil, fmt.Errorf("error getting block source account: %w", err) + } + + // Set the block target account + block.TargetAccount, err = r.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + block.TargetAccountID, + ) + if err != nil { + return nil, fmt.Errorf("error getting block target account: %w", err) + } + + return block, nil +} + +func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) error { + err := r.state.Caches.GTS.Block().Store(block, func() error { + _, err := r.conn.NewInsert().Model(block).Exec(ctx) + return r.conn.ProcessError(err) + }) + if err != nil { + return err + } + + // Invalidate block origin account ID cached visibility. + r.state.Caches.Visibility.Invalidate("ItemID", block.AccountID) + r.state.Caches.Visibility.Invalidate("RequesterID", block.AccountID) + + // Invalidate block target account ID cached visibility. + r.state.Caches.Visibility.Invalidate("ItemID", block.TargetAccountID) + r.state.Caches.Visibility.Invalidate("RequesterID", block.TargetAccountID) + + return nil +} + +func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error { + block, err := r.GetBlockByID(gtscontext.SetBarebones(ctx), id) + if err != nil { + return err + } + return r.deleteBlock(ctx, block) +} + +func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error { + block, err := r.GetBlockByURI(gtscontext.SetBarebones(ctx), uri) + if err != nil { + return err + } + return r.deleteBlock(ctx, block) +} + +func (r *relationshipDB) deleteBlock(ctx context.Context, block *gtsmodel.Block) error { + if _, err := r.conn. + NewDelete(). + Table("blocks"). + Where("? = ?", bun.Ident("id"), block.ID). + Exec(ctx); err != nil { + return r.conn.ProcessError(err) + } + + // Invalidate block from cache lookups. + r.state.Caches.GTS.Block().Invalidate("ID", block.ID) + + return nil +} + +func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID string) error { + var blockIDs []string + + if err := r.conn.NewSelect(). + Table("blocks"). + ColumnExpr("?", bun.Ident("id")). + WhereOr("? = ? OR ? = ?", + bun.Ident("account_id"), + accountID, + bun.Ident("target_account_id"), + accountID, + ). + Scan(ctx, &blockIDs); err != nil { + return r.conn.ProcessError(err) + } + + for _, id := range blockIDs { + if err := r.DeleteBlockByID(ctx, id); err != nil { + log.Errorf(ctx, "error deleting block %q: %v", id, err) + } + } + + return nil +} diff --git a/internal/db/bundb/relationship_follow.go b/internal/db/bundb/relationship_follow.go new file mode 100644 index 000000000..4a315d116 --- /dev/null +++ b/internal/db/bundb/relationship_follow.go @@ -0,0 +1,243 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( + "context" + "errors" + "fmt" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/uptrace/bun" +) + +func (r *relationshipDB) GetFollowByID(ctx context.Context, id string) (*gtsmodel.Follow, error) { + return r.getFollow( + ctx, + "ID", + func(follow *gtsmodel.Follow) error { + return r.conn.NewSelect(). + Model(follow). + Where("? = ?", bun.Ident("id"), id). + Scan(ctx) + }, + id, + ) +} + +func (r *relationshipDB) GetFollowByURI(ctx context.Context, uri string) (*gtsmodel.Follow, error) { + return r.getFollow( + ctx, + "URI", + func(follow *gtsmodel.Follow) error { + return r.conn.NewSelect(). + Model(follow). + Where("? = ?", bun.Ident("uri"), uri). + Scan(ctx) + }, + uri, + ) +} + +func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) { + return r.getFollow( + ctx, + "AccountID.TargetAccountID", + func(follow *gtsmodel.Follow) error { + return r.conn.NewSelect(). + Model(follow). + Where("? = ?", bun.Ident("account_id"), sourceAccountID). + Where("? = ?", bun.Ident("target_account_id"), targetAccountID). + Scan(ctx) + }, + sourceAccountID, + targetAccountID, + ) +} + +func (r *relationshipDB) GetFollowsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Follow, error) { + // Preallocate slice of expected length. + follows := make([]*gtsmodel.Follow, 0, len(ids)) + + for _, id := range ids { + // Fetch follow model for this ID. + follow, err := r.GetFollowByID(ctx, id) + if err != nil { + log.Errorf(ctx, "error getting follow %q: %v", id, err) + continue + } + + // Append to return slice. + follows = append(follows, follow) + } + + return follows, nil +} + +func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) { + follow, err := r.GetFollow( + gtscontext.SetBarebones(ctx), + sourceAccountID, + targetAccountID, + ) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return false, err + } + return (follow != nil), nil +} + +func (r *relationshipDB) IsMutualFollowing(ctx context.Context, accountID1 string, accountID2 string) (bool, db.Error) { + // make sure account 1 follows account 2 + f1, err := r.IsFollowing(ctx, + accountID1, + accountID2, + ) + if !f1 /* f1 = false when err != nil */ { + return false, err + } + + // make sure account 2 follows account 1 + f2, err := r.IsFollowing(ctx, + accountID2, + accountID1, + ) + if !f2 /* f2 = false when err != nil */ { + return false, err + } + + return true, nil +} + +func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Follow) error, keyParts ...any) (*gtsmodel.Follow, error) { + // Fetch follow from database cache with loader callback + follow, err := r.state.Caches.GTS.Follow().Load(lookup, func() (*gtsmodel.Follow, error) { + var follow gtsmodel.Follow + + // Not cached! Perform database query + if err := dbQuery(&follow); err != nil { + return nil, r.conn.ProcessError(err) + } + + return &follow, nil + }, keyParts...) + if err != nil { + // error already processed + return nil, err + } + + if gtscontext.Barebones(ctx) { + // Only a barebones model was requested. + return follow, nil + } + + // Set the follow source account + follow.Account, err = r.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + follow.AccountID, + ) + if err != nil { + return nil, fmt.Errorf("error getting follow source account: %w", err) + } + + // Set the follow target account + follow.TargetAccount, err = r.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + follow.TargetAccountID, + ) + if err != nil { + return nil, fmt.Errorf("error getting follow target account: %w", err) + } + + return follow, nil +} + +func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error { + err := r.state.Caches.GTS.Follow().Store(follow, func() error { + _, err := r.conn.NewInsert().Model(follow).Exec(ctx) + return r.conn.ProcessError(err) + }) + if err != nil { + return err + } + + // Invalidate follow origin account ID cached visibility. + r.state.Caches.Visibility.Invalidate("ItemID", follow.AccountID) + r.state.Caches.Visibility.Invalidate("RequesterID", follow.AccountID) + + // Invalidate follow target account ID cached visibility. + r.state.Caches.Visibility.Invalidate("ItemID", follow.TargetAccountID) + r.state.Caches.Visibility.Invalidate("RequesterID", follow.TargetAccountID) + + return nil +} + +func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error { + if _, err := r.conn.NewDelete(). + Table("follows"). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx); err != nil { + return r.conn.ProcessError(err) + } + + // Invalidate follow from cache lookups. + r.state.Caches.GTS.Follow().Invalidate("ID", id) + + return nil +} + +func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) error { + if _, err := r.conn.NewDelete(). + Table("follows"). + Where("? = ?", bun.Ident("uri"), uri). + Exec(ctx); err != nil { + return r.conn.ProcessError(err) + } + + // Invalidate follow from cache lookups. + r.state.Caches.GTS.Follow().Invalidate("URI", uri) + + return nil +} + +func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID string) error { + var followIDs []string + + if _, err := r.conn. + NewDelete(). + Table("follows"). + WhereOr("? = ? OR ? = ?", + bun.Ident("account_id"), + accountID, + bun.Ident("target_account_id"), + accountID, + ). + Returning("?", bun.Ident("id")). + Exec(ctx, &followIDs); err != nil { + return r.conn.ProcessError(err) + } + + // Invalidate each returned ID. + for _, id := range followIDs { + r.state.Caches.GTS.Follow().Invalidate("ID", id) + } + + return nil +} diff --git a/internal/db/bundb/relationship_follow_req.go b/internal/db/bundb/relationship_follow_req.go new file mode 100644 index 000000000..11200338d --- /dev/null +++ b/internal/db/bundb/relationship_follow_req.go @@ -0,0 +1,293 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( + "context" + "errors" + "fmt" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/uptrace/bun" +) + +func (r *relationshipDB) GetFollowRequestByID(ctx context.Context, id string) (*gtsmodel.FollowRequest, error) { + return r.getFollowRequest( + ctx, + "ID", + func(followReq *gtsmodel.FollowRequest) error { + return r.conn.NewSelect(). + Model(followReq). + Where("? = ?", bun.Ident("id"), id). + Scan(ctx) + }, + id, + ) +} + +func (r *relationshipDB) GetFollowRequestByURI(ctx context.Context, uri string) (*gtsmodel.FollowRequest, error) { + return r.getFollowRequest( + ctx, + "URI", + func(followReq *gtsmodel.FollowRequest) error { + return r.conn.NewSelect(). + Model(followReq). + Where("? = ?", bun.Ident("uri"), uri). + Scan(ctx) + }, + uri, + ) +} + +func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error) { + return r.getFollowRequest( + ctx, + "AccountID.TargetAccountID", + func(followReq *gtsmodel.FollowRequest) error { + return r.conn.NewSelect(). + Model(followReq). + Where("? = ?", bun.Ident("account_id"), sourceAccountID). + Where("? = ?", bun.Ident("target_account_id"), targetAccountID). + Scan(ctx) + }, + sourceAccountID, + targetAccountID, + ) +} + +func (r *relationshipDB) GetFollowRequestsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.FollowRequest, error) { + // Preallocate slice of expected length. + followReqs := make([]*gtsmodel.FollowRequest, 0, len(ids)) + + for _, id := range ids { + // Fetch follow request model for this ID. + followReq, err := r.GetFollowRequestByID(ctx, id) + if err != nil { + log.Errorf(ctx, "error getting follow request %q: %v", id, err) + continue + } + + // Append to return slice. + followReqs = append(followReqs, followReq) + } + + return followReqs, nil +} + +func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) { + followReq, err := r.GetFollowRequest( + gtscontext.SetBarebones(ctx), + sourceAccountID, + targetAccountID, + ) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return false, err + } + return (followReq != nil), nil +} + +func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, dbQuery func(*gtsmodel.FollowRequest) error, keyParts ...any) (*gtsmodel.FollowRequest, error) { + // Fetch follow request from database cache with loader callback + followReq, err := r.state.Caches.GTS.FollowRequest().Load(lookup, func() (*gtsmodel.FollowRequest, error) { + var followReq gtsmodel.FollowRequest + + // Not cached! Perform database query + if err := dbQuery(&followReq); err != nil { + return nil, r.conn.ProcessError(err) + } + + return &followReq, nil + }, keyParts...) + if err != nil { + // error already processed + return nil, err + } + + if gtscontext.Barebones(ctx) { + // Only a barebones model was requested. + return followReq, nil + } + + // Set the follow request source account + followReq.Account, err = r.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + followReq.AccountID, + ) + if err != nil { + return nil, fmt.Errorf("error getting follow request source account: %w", err) + } + + // Set the follow request target account + followReq.TargetAccount, err = r.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + followReq.TargetAccountID, + ) + if err != nil { + return nil, fmt.Errorf("error getting follow request target account: %w", err) + } + + return followReq, nil +} + +func (r *relationshipDB) PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error { + err := r.state.Caches.GTS.FollowRequest().Store(follow, func() error { + _, err := r.conn.NewInsert().Model(follow).Exec(ctx) + return r.conn.ProcessError(err) + }) + if err != nil { + return err + } + + // Invalidate follow request origin account ID cached visibility. + r.state.Caches.Visibility.Invalidate("ItemID", follow.AccountID) + r.state.Caches.Visibility.Invalidate("RequesterID", follow.AccountID) + + // Invalidate follow request target account ID cached visibility. + r.state.Caches.Visibility.Invalidate("ItemID", follow.TargetAccountID) + r.state.Caches.Visibility.Invalidate("RequesterID", follow.TargetAccountID) + + return nil +} + +func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) { + // Get original follow request. + followReq, err := r.GetFollowRequest(ctx, sourceAccountID, targetAccountID) + if err != nil { + return nil, err + } + + // Create a new follow to 'replace' + // the original follow request with. + follow := >smodel.Follow{ + ID: followReq.ID, + AccountID: sourceAccountID, + Account: followReq.Account, + TargetAccountID: targetAccountID, + TargetAccount: followReq.TargetAccount, + URI: followReq.URI, + } + + // If the follow already exists, just + // replace the URI with the new one. + if _, err := r.conn. + NewInsert(). + Model(follow). + On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI). + Exec(ctx); err != nil { + return nil, r.conn.ProcessError(err) + } + + // Delete original follow request. + if _, err := r.conn. + NewDelete(). + Table("follow_requests"). + Where("? = ?", bun.Ident("id"), followReq.ID). + Exec(ctx); err != nil { + return nil, r.conn.ProcessError(err) + } + + // Invalidate follow request from cache lookups. + r.state.Caches.GTS.FollowRequest().Invalidate("ID", followReq.ID) + + // Delete original follow request notification + if err := r.state.DB.DeleteNotifications(ctx, []string{ + string(gtsmodel.NotificationFollowRequest), + }, targetAccountID, sourceAccountID); err != nil { + return nil, err + } + + return follow, nil +} + +func (r *relationshipDB) RejectFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) db.Error { + // Get original follow request. + followReq, err := r.GetFollowRequest(ctx, sourceAccountID, targetAccountID) + if err != nil { + return err + } + + // Delete original follow request. + if _, err := r.conn. + NewDelete(). + Table("follow_requests"). + Where("? = ?", bun.Ident("id"), followReq.ID). + Exec(ctx); err != nil { + return r.conn.ProcessError(err) + } + + // Delete original follow request notification + return r.state.DB.DeleteNotifications(ctx, []string{ + string(gtsmodel.NotificationFollowRequest), + }, targetAccountID, sourceAccountID) +} + +func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) error { + if _, err := r.conn.NewDelete(). + Table("follow_requests"). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx); err != nil { + return r.conn.ProcessError(err) + } + + // Invalidate follow request from cache lookups. + r.state.Caches.GTS.FollowRequest().Invalidate("ID", id) + + return nil +} + +func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri string) error { + if _, err := r.conn.NewDelete(). + Table("follow_requests"). + Where("? = ?", bun.Ident("uri"), uri). + Exec(ctx); err != nil { + return r.conn.ProcessError(err) + } + + // Invalidate follow request from cache lookups. + r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri) + + return nil +} + +func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accountID string) error { + var followIDs []string + + if _, err := r.conn. + NewDelete(). + Table("follow_requests"). + WhereOr("? = ? OR ? = ?", + bun.Ident("account_id"), + accountID, + bun.Ident("target_account_id"), + accountID, + ). + Returning("?", bun.Ident("id")). + Exec(ctx, &followIDs); err != nil { + return r.conn.ProcessError(err) + } + + // Invalidate each returned ID. + for _, id := range followIDs { + r.state.Caches.GTS.FollowRequest().Invalidate("ID", id) + } + + return nil +} diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go index 3d307ecde..00583d175 100644 --- a/internal/db/bundb/relationship_test.go +++ b/internal/db/bundb/relationship_test.go @@ -19,17 +19,359 @@ package bundb_test import ( "context" + "errors" + "reflect" "testing" + "time" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" ) type RelationshipTestSuite struct { BunDBStandardTestSuite } +func (suite *RelationshipTestSuite) TestGetBlockBy() { + t := suite.T() + + // Create a new context for this test. + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + // Sentinel error to mark avoiding a test case. + sentinelErr := errors.New("sentinel") + + // isEqual checks if 2 block models are equal. + isEqual := func(b1, b2 gtsmodel.Block) bool { + // Clear populated sub-models. + b1.Account = nil + b2.Account = nil + b1.TargetAccount = nil + b2.TargetAccount = nil + + // Clear database-set fields. + b1.CreatedAt = time.Time{} + b2.CreatedAt = time.Time{} + b1.UpdatedAt = time.Time{} + b2.UpdatedAt = time.Time{} + + return reflect.DeepEqual(b1, b2) + } + + var testBlocks []*gtsmodel.Block + + for _, account1 := range suite.testAccounts { + for _, account2 := range suite.testAccounts { + if account1.ID == account2.ID { + // don't block *yourself* ... + continue + } + + // Create new account block. + block := >smodel.Block{ + ID: id.NewULID(), + URI: "http://127.0.0.1:8080/" + id.NewULID(), + AccountID: account1.ID, + TargetAccountID: account2.ID, + } + + // Attempt to place the block in database (if not already). + if err := suite.db.PutBlock(ctx, block); err != nil { + if err != db.ErrAlreadyExists { + // Unrecoverable database error. + t.Fatalf("error creating block: %v", err) + } + + // Fetch existing block from database between accounts. + block, _ = suite.db.GetBlock(ctx, account1.ID, account2.ID) + continue + } + + // Append generated block to test cases. + testBlocks = append(testBlocks, block) + } + } + + for _, block := range testBlocks { + for lookup, dbfunc := range map[string]func() (*gtsmodel.Block, error){ + "id": func() (*gtsmodel.Block, error) { + return suite.db.GetBlockByID(ctx, block.ID) + }, + + "uri": func() (*gtsmodel.Block, error) { + return suite.db.GetBlockByURI(ctx, block.URI) + }, + + "origin_target": func() (*gtsmodel.Block, error) { + return suite.db.GetBlock(ctx, block.AccountID, block.TargetAccountID) + }, + } { + + // Clear database caches. + suite.state.Caches.Init() + + t.Logf("checking database lookup %q", lookup) + + // Perform database function. + checkBlock, err := dbfunc() + if err != nil { + if err == sentinelErr { + continue + } + + t.Errorf("error encountered for database lookup %q: %v", lookup, err) + continue + } + + // Check received block data. + if !isEqual(*checkBlock, *block) { + t.Errorf("block does not contain expected data: %+v", checkBlock) + continue + } + + // Check that block origin account populated. + if checkBlock.Account == nil || checkBlock.Account.ID != block.AccountID { + t.Errorf("block origin account not correctly populated for: %+v", checkBlock) + continue + } + + // Check that block target account populated. + if checkBlock.TargetAccount == nil || checkBlock.TargetAccount.ID != block.TargetAccountID { + t.Errorf("block target account not correctly populated for: %+v", checkBlock) + continue + } + } + } +} + +func (suite *RelationshipTestSuite) TestGetFollowBy() { + t := suite.T() + + // Create a new context for this test. + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + // Sentinel error to mark avoiding a test case. + sentinelErr := errors.New("sentinel") + + // isEqual checks if 2 follow models are equal. + isEqual := func(f1, f2 gtsmodel.Follow) bool { + // Clear populated sub-models. + f1.Account = nil + f2.Account = nil + f1.TargetAccount = nil + f2.TargetAccount = nil + + // Clear database-set fields. + f1.CreatedAt = time.Time{} + f2.CreatedAt = time.Time{} + f1.UpdatedAt = time.Time{} + f2.UpdatedAt = time.Time{} + + return reflect.DeepEqual(f1, f2) + } + + var testFollows []*gtsmodel.Follow + + for _, account1 := range suite.testAccounts { + for _, account2 := range suite.testAccounts { + if account1.ID == account2.ID { + // don't follow *yourself* ... + continue + } + + // Create new account follow. + follow := >smodel.Follow{ + ID: id.NewULID(), + URI: "http://127.0.0.1:8080/" + id.NewULID(), + AccountID: account1.ID, + TargetAccountID: account2.ID, + } + + // Attempt to place the follow in database (if not already). + if err := suite.db.PutFollow(ctx, follow); err != nil { + if err != db.ErrAlreadyExists { + // Unrecoverable database error. + t.Fatalf("error creating follow: %v", err) + } + + // Fetch existing follow from database between accounts. + follow, _ = suite.db.GetFollow(ctx, account1.ID, account2.ID) + continue + } + + // Append generated follow to test cases. + testFollows = append(testFollows, follow) + } + } + + for _, follow := range testFollows { + for lookup, dbfunc := range map[string]func() (*gtsmodel.Follow, error){ + "id": func() (*gtsmodel.Follow, error) { + return suite.db.GetFollowByID(ctx, follow.ID) + }, + + "uri": func() (*gtsmodel.Follow, error) { + return suite.db.GetFollowByURI(ctx, follow.URI) + }, + + "origin_target": func() (*gtsmodel.Follow, error) { + return suite.db.GetFollow(ctx, follow.AccountID, follow.TargetAccountID) + }, + } { + // Clear database caches. + suite.state.Caches.Init() + + t.Logf("checking database lookup %q", lookup) + + // Perform database function. + checkFollow, err := dbfunc() + if err != nil { + if err == sentinelErr { + continue + } + + t.Errorf("error encountered for database lookup %q: %v", lookup, err) + continue + } + + // Check received follow data. + if !isEqual(*checkFollow, *follow) { + t.Errorf("follow does not contain expected data: %+v", checkFollow) + continue + } + + // Check that follow origin account populated. + if checkFollow.Account == nil || checkFollow.Account.ID != follow.AccountID { + t.Errorf("follow origin account not correctly populated for: %+v", checkFollow) + continue + } + + // Check that follow target account populated. + if checkFollow.TargetAccount == nil || checkFollow.TargetAccount.ID != follow.TargetAccountID { + t.Errorf("follow target account not correctly populated for: %+v", checkFollow) + continue + } + } + } +} + +func (suite *RelationshipTestSuite) TestGetFollowRequestBy() { + t := suite.T() + + // Create a new context for this test. + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + // Sentinel error to mark avoiding a test case. + sentinelErr := errors.New("sentinel") + + // isEqual checks if 2 follow request models are equal. + isEqual := func(f1, f2 gtsmodel.FollowRequest) bool { + // Clear populated sub-models. + f1.Account = nil + f2.Account = nil + f1.TargetAccount = nil + f2.TargetAccount = nil + + // Clear database-set fields. + f1.CreatedAt = time.Time{} + f2.CreatedAt = time.Time{} + f1.UpdatedAt = time.Time{} + f2.UpdatedAt = time.Time{} + + return reflect.DeepEqual(f1, f2) + } + + var testFollowReqs []*gtsmodel.FollowRequest + + for _, account1 := range suite.testAccounts { + for _, account2 := range suite.testAccounts { + if account1.ID == account2.ID { + // don't follow *yourself* ... + continue + } + + // Create new account follow request. + followReq := >smodel.FollowRequest{ + ID: id.NewULID(), + URI: "http://127.0.0.1:8080/" + id.NewULID(), + AccountID: account1.ID, + TargetAccountID: account2.ID, + } + + // Attempt to place the follow in database (if not already). + if err := suite.db.PutFollowRequest(ctx, followReq); err != nil { + if err != db.ErrAlreadyExists { + // Unrecoverable database error. + t.Fatalf("error creating follow request: %v", err) + } + + // Fetch existing follow request from database between accounts. + followReq, _ = suite.db.GetFollowRequest(ctx, account1.ID, account2.ID) + continue + } + + // Append generated follow request to test cases. + testFollowReqs = append(testFollowReqs, followReq) + } + } + + for _, followReq := range testFollowReqs { + for lookup, dbfunc := range map[string]func() (*gtsmodel.FollowRequest, error){ + "id": func() (*gtsmodel.FollowRequest, error) { + return suite.db.GetFollowRequestByID(ctx, followReq.ID) + }, + + "uri": func() (*gtsmodel.FollowRequest, error) { + return suite.db.GetFollowRequestByURI(ctx, followReq.URI) + }, + + "origin_target": func() (*gtsmodel.FollowRequest, error) { + return suite.db.GetFollowRequest(ctx, followReq.AccountID, followReq.TargetAccountID) + }, + } { + + // Clear database caches. + suite.state.Caches.Init() + + t.Logf("checking database lookup %q", lookup) + + // Perform database function. + checkFollowReq, err := dbfunc() + if err != nil { + if err == sentinelErr { + continue + } + + t.Errorf("error encountered for database lookup %q: %v", lookup, err) + continue + } + + // Check received follow request data. + if !isEqual(*checkFollowReq, *followReq) { + t.Errorf("follow request does not contain expected data: %+v", checkFollowReq) + continue + } + + // Check that follow request origin account populated. + if checkFollowReq.Account == nil || checkFollowReq.Account.ID != followReq.AccountID { + t.Errorf("follow request origin account not correctly populated for: %+v", checkFollowReq) + continue + } + + // Check that follow request target account populated. + if checkFollowReq.TargetAccount == nil || checkFollowReq.TargetAccount.ID != followReq.TargetAccountID { + t.Errorf("follow request target account not correctly populated for: %+v", checkFollowReq) + continue + } + } + } +} + func (suite *RelationshipTestSuite) TestIsBlocked() { ctx := context.Background() @@ -37,11 +379,11 @@ func (suite *RelationshipTestSuite) TestIsBlocked() { account2 := suite.testAccounts["local_account_2"].ID // no blocks exist between account 1 and account 2 - blocked, err := suite.db.IsBlocked(ctx, account1, account2, false) + blocked, err := suite.db.IsBlocked(ctx, account1, account2) suite.NoError(err) suite.False(blocked) - blocked, err = suite.db.IsBlocked(ctx, account2, account1, false) + blocked, err = suite.db.IsBlocked(ctx, account2, account1) suite.NoError(err) suite.False(blocked) @@ -56,45 +398,24 @@ func (suite *RelationshipTestSuite) TestIsBlocked() { } // account 1 now blocks account 2 - blocked, err = suite.db.IsBlocked(ctx, account1, account2, false) + blocked, err = suite.db.IsBlocked(ctx, account1, account2) suite.NoError(err) suite.True(blocked) // account 2 doesn't block account 1 - blocked, err = suite.db.IsBlocked(ctx, account2, account1, false) + blocked, err = suite.db.IsBlocked(ctx, account2, account1) suite.NoError(err) suite.False(blocked) // a block exists in either direction between the two - blocked, err = suite.db.IsBlocked(ctx, account1, account2, true) + blocked, err = suite.db.IsEitherBlocked(ctx, account1, account2) suite.NoError(err) suite.True(blocked) - blocked, err = suite.db.IsBlocked(ctx, account2, account1, true) + blocked, err = suite.db.IsEitherBlocked(ctx, account2, account1) suite.NoError(err) suite.True(blocked) } -func (suite *RelationshipTestSuite) TestGetBlock() { - ctx := context.Background() - - account1 := suite.testAccounts["local_account_1"].ID - account2 := suite.testAccounts["local_account_2"].ID - - if err := suite.db.PutBlock(ctx, >smodel.Block{ - ID: "01G202BCSXXJZ70BHB5KCAHH8C", - URI: "http://localhost:8080/some_block_uri_1", - AccountID: account1, - TargetAccountID: account2, - }); err != nil { - suite.FailNow(err.Error()) - } - - block, err := suite.db.GetBlock(ctx, account1, account2) - suite.NoError(err) - suite.NotNil(block) - suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID) -} - func (suite *RelationshipTestSuite) TestDeleteBlockByID() { ctx := context.Background() @@ -157,7 +478,7 @@ func (suite *RelationshipTestSuite) TestDeleteBlockByURI() { suite.Nil(block) } -func (suite *RelationshipTestSuite) TestDeleteBlocksByOriginAccountID() { +func (suite *RelationshipTestSuite) TestDeleteAccountBlocks() { ctx := context.Background() // put a block in first @@ -179,38 +500,7 @@ func (suite *RelationshipTestSuite) TestDeleteBlocksByOriginAccountID() { suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID) // delete the block by originAccountID - err = suite.db.DeleteBlocksByOriginAccountID(ctx, account1) - suite.NoError(err) - - // block should be gone - block, err = suite.db.GetBlock(ctx, account1, account2) - suite.ErrorIs(err, db.ErrNoEntries) - suite.Nil(block) -} - -func (suite *RelationshipTestSuite) TestDeleteBlocksByTargetAccountID() { - ctx := context.Background() - - // put a block in first - account1 := suite.testAccounts["local_account_1"].ID - account2 := suite.testAccounts["local_account_2"].ID - if err := suite.db.PutBlock(ctx, >smodel.Block{ - ID: "01G202BCSXXJZ70BHB5KCAHH8C", - URI: "http://localhost:8080/some_block_uri_1", - AccountID: account1, - TargetAccountID: account2, - }); err != nil { - suite.FailNow(err.Error()) - } - - // make sure the block is in the db - block, err := suite.db.GetBlock(ctx, account1, account2) - suite.NoError(err) - suite.NotNil(block) - suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID) - - // delete the block by targetAccountID - err = suite.db.DeleteBlocksByTargetAccountID(ctx, account2) + err = suite.db.DeleteAccountBlocks(ctx, account1) suite.NoError(err) // block should be gone @@ -244,7 +534,7 @@ func (suite *RelationshipTestSuite) TestGetRelationship() { func (suite *RelationshipTestSuite) TestIsFollowingYes() { requestingAccount := suite.testAccounts["local_account_1"] targetAccount := suite.testAccounts["admin_account"] - isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount) + isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount.ID, targetAccount.ID) suite.NoError(err) suite.True(isFollowing) } @@ -252,7 +542,7 @@ func (suite *RelationshipTestSuite) TestIsFollowingYes() { func (suite *RelationshipTestSuite) TestIsFollowingNo() { requestingAccount := suite.testAccounts["admin_account"] targetAccount := suite.testAccounts["local_account_2"] - isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount) + isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount.ID, targetAccount.ID) suite.NoError(err) suite.False(isFollowing) } @@ -260,7 +550,7 @@ func (suite *RelationshipTestSuite) TestIsFollowingNo() { func (suite *RelationshipTestSuite) TestIsMutualFollowing() { requestingAccount := suite.testAccounts["local_account_1"] targetAccount := suite.testAccounts["admin_account"] - isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount) + isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount.ID, targetAccount.ID) suite.NoError(err) suite.True(isMutualFollowing) } @@ -268,7 +558,7 @@ func (suite *RelationshipTestSuite) TestIsMutualFollowing() { func (suite *RelationshipTestSuite) TestIsMutualFollowingNo() { requestingAccount := suite.testAccounts["local_account_1"] targetAccount := suite.testAccounts["local_account_2"] - isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount) + isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount.ID, targetAccount.ID) suite.NoError(err) suite.True(isMutualFollowing) } @@ -306,7 +596,7 @@ func (suite *RelationshipTestSuite) TestAcceptFollowRequestOK() { suite.Equal(followRequest.URI, follow.URI) // Ensure notification is deleted. - notification, err := suite.db.GetNotification(ctx, followRequestNotification.ID) + notification, err := suite.db.GetNotificationByID(ctx, followRequestNotification.ID) suite.ErrorIs(err, db.ErrNoEntries) suite.Nil(notification) } @@ -389,7 +679,7 @@ func (suite *RelationshipTestSuite) TestRejectFollowRequestOK() { TargetAccountID: targetAccount.ID, } - if err := suite.db.Put(ctx, followRequest); err != nil { + if err := suite.db.PutFollowRequest(ctx, followRequest); err != nil { suite.FailNow(err.Error()) } @@ -404,12 +694,11 @@ func (suite *RelationshipTestSuite) TestRejectFollowRequestOK() { suite.FailNow(err.Error()) } - rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID) + err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID) suite.NoError(err) - suite.NotNil(rejectedFollowRequest) // Ensure notification is deleted. - notification, err := suite.db.GetNotification(ctx, followRequestNotification.ID) + notification, err := suite.db.GetNotificationByID(ctx, followRequestNotification.ID) suite.ErrorIs(err, db.ErrNoEntries) suite.Nil(notification) } @@ -419,9 +708,8 @@ func (suite *RelationshipTestSuite) TestRejectFollowRequestNotExisting() { account := suite.testAccounts["admin_account"] targetAccount := suite.testAccounts["local_account_2"] - rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID) + err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID) suite.ErrorIs(err, db.ErrNoEntries) - suite.Nil(rejectedFollowRequest) } func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() { @@ -440,42 +728,49 @@ func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() { suite.FailNow(err.Error()) } - followRequests, err := suite.db.GetFollowRequests(ctx, "", targetAccount.ID) + followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID) suite.NoError(err) suite.Len(followRequests, 1) } func (suite *RelationshipTestSuite) TestGetAccountFollows() { account := suite.testAccounts["local_account_1"] - follows, err := suite.db.GetFollows(context.Background(), account.ID, "") + follows, err := suite.db.GetAccountFollows(context.Background(), account.ID) suite.NoError(err) suite.Len(follows, 2) } +func (suite *RelationshipTestSuite) TestCountAccountFollowsLocalOnly() { + account := suite.testAccounts["local_account_1"] + followsCount, err := suite.db.CountAccountLocalFollows(context.Background(), account.ID) + suite.NoError(err) + suite.Equal(2, followsCount) +} + func (suite *RelationshipTestSuite) TestCountAccountFollows() { account := suite.testAccounts["local_account_1"] - followsCount, err := suite.db.CountFollows(context.Background(), account.ID, "") + followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID) suite.NoError(err) suite.Equal(2, followsCount) } -func (suite *RelationshipTestSuite) TestGetAccountFollowedBy() { +func (suite *RelationshipTestSuite) TestGetAccountFollowers() { account := suite.testAccounts["local_account_1"] - follows, err := suite.db.GetFollows(context.Background(), "", account.ID) + follows, err := suite.db.GetAccountFollowers(context.Background(), account.ID) suite.NoError(err) suite.Len(follows, 2) } -func (suite *RelationshipTestSuite) TestGetLocalFollowersIDs() { +func (suite *RelationshipTestSuite) TestCountAccountFollowers() { account := suite.testAccounts["local_account_1"] - accountIDs, err := suite.db.GetLocalFollowersIDs(context.Background(), account.ID) + followsCount, err := suite.db.CountAccountFollowers(context.Background(), account.ID) suite.NoError(err) - suite.EqualValues([]string{"01F8MH5NBDF2MV7CTC4Q5128HF", "01F8MH17FWEB39HZJ76B6VXSKF"}, accountIDs) + suite.Equal(2, followsCount) } -func (suite *RelationshipTestSuite) TestCountAccountFollowedBy() { +func (suite *RelationshipTestSuite) TestCountAccountFollowersLocalOnly() { account := suite.testAccounts["local_account_1"] - followsCount, err := suite.db.CountFollows(context.Background(), "", account.ID) + followsCount, err := suite.db.CountAccountLocalFollowers(context.Background(), account.ID) suite.NoError(err) suite.Equal(2, followsCount) } @@ -484,18 +779,25 @@ func (suite *RelationshipTestSuite) TestUnfollowExisting() { originAccount := suite.testAccounts["local_account_1"] targetAccount := suite.testAccounts["admin_account"] - uri, err := suite.db.Unfollow(context.Background(), originAccount.ID, targetAccount.ID) + follow, err := suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID) suite.NoError(err) - suite.Equal("http://localhost:8080/users/the_mighty_zork/follow/01F8PY8RHWRQZV038T4E8T9YK8", uri) + suite.NotNil(follow) + + err = suite.db.DeleteFollowByID(context.Background(), follow.ID) + suite.NoError(err) + + follow, err = suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID) + suite.EqualError(err, db.ErrNoEntries.Error()) + suite.Nil(follow) } func (suite *RelationshipTestSuite) TestUnfollowNotExisting() { originAccount := suite.testAccounts["local_account_1"] targetAccountID := "01GTVD9N484CZ6AM90PGGNY7GQ" - uri, err := suite.db.Unfollow(context.Background(), originAccount.ID, targetAccountID) - suite.NoError(err) - suite.Empty(uri) + follow, err := suite.db.GetFollow(context.Background(), originAccount.ID, targetAccountID) + suite.EqualError(err, db.ErrNoEntries.Error()) + suite.Nil(follow) } func (suite *RelationshipTestSuite) TestUnfollowRequestExisting() { @@ -510,22 +812,29 @@ func (suite *RelationshipTestSuite) TestUnfollowRequestExisting() { TargetAccountID: targetAccount.ID, } - if err := suite.db.Put(ctx, followRequest); err != nil { + if err := suite.db.PutFollowRequest(ctx, followRequest); err != nil { suite.FailNow(err.Error()) } - uri, err := suite.db.UnfollowRequest(context.Background(), originAccount.ID, targetAccount.ID) + followRequest, err := suite.db.GetFollowRequest(context.Background(), originAccount.ID, targetAccount.ID) + suite.NoError(err) + suite.NotNil(followRequest) + + err = suite.db.DeleteFollowRequestByID(context.Background(), followRequest.ID) suite.NoError(err) - suite.Equal("http://localhost:8080/weeeeeeeeeeeeeeeee", uri) + + followRequest, err = suite.db.GetFollowRequest(context.Background(), originAccount.ID, targetAccount.ID) + suite.EqualError(err, db.ErrNoEntries.Error()) + suite.Nil(followRequest) } func (suite *RelationshipTestSuite) TestUnfollowRequestNotExisting() { originAccount := suite.testAccounts["local_account_1"] targetAccountID := "01GTVD9N484CZ6AM90PGGNY7GQ" - uri, err := suite.db.UnfollowRequest(context.Background(), originAccount.ID, targetAccountID) - suite.NoError(err) - suite.Empty(uri) + followRequest, err := suite.db.GetFollowRequest(context.Background(), originAccount.ID, targetAccountID) + suite.EqualError(err, db.ErrNoEntries.Error()) + suite.Nil(followRequest) } func TestRelationshipTestSuite(t *testing.T) { diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index deec9a118..c2b5546f8 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -26,6 +26,7 @@ import ( "time" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" @@ -41,7 +42,6 @@ func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { return s.conn. NewSelect(). Model(status). - Relation("Attachments"). Relation("Tags"). Relation("CreatedWithApplication") } @@ -102,79 +102,141 @@ func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*g status, err := s.state.Caches.GTS.Status().Load(lookup, func() (*gtsmodel.Status, error) { var status gtsmodel.Status - // Not cached! Perform database query + // Not cached! Perform database query. if err := dbQuery(&status); err != nil { return nil, s.conn.ProcessError(err) } - if status.InReplyToID != "" { - // Also load in-reply-to status - status.InReplyTo = new(gtsmodel.Status) - err := s.conn.NewSelect().Model(status.InReplyTo). - Where("? = ?", bun.Ident("status.id"), status.InReplyToID). - Scan(ctx) + return &status, nil + }, keyParts...) + if err != nil { + return nil, err + } + + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return status, nil + } + + // Further populate the status fields where applicable. + if err := s.PopulateStatus(ctx, status); err != nil { + return nil, err + } + + return status, nil +} + +func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) error { + var err error + + if status.Account == nil { + // Status author is not set, fetch from database. + status.Account, err = s.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + status.AccountID, + ) + if err != nil { + return fmt.Errorf("error populating status author: %w", err) + } + } + + if status.InReplyToID != "" && status.InReplyTo == nil { + // Status parent is not set, fetch from database. + status.InReplyTo, err = s.GetStatusByID( + gtscontext.SetBarebones(ctx), + status.InReplyToID, + ) + if err != nil { + return fmt.Errorf("error populating status parent: %w", err) + } + } + + if status.InReplyToID != "" { + if status.InReplyTo == nil { + // Status parent is not set, fetch from database. + status.InReplyTo, err = s.GetStatusByID( + gtscontext.SetBarebones(ctx), + status.InReplyToID, + ) if err != nil { - return nil, s.conn.ProcessError(err) + return fmt.Errorf("error populating status parent: %w", err) } } - if status.BoostOfID != "" { - // Also load original boosted status - status.BoostOf = new(gtsmodel.Status) - err := s.conn.NewSelect().Model(status.BoostOf). - Where("? = ?", bun.Ident("status.id"), status.BoostOfID). - Scan(ctx) + if status.InReplyToAccount == nil { + // Status parent author is not set, fetch from database. + status.InReplyToAccount, err = s.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + status.InReplyToAccountID, + ) if err != nil { - return nil, s.conn.ProcessError(err) + return fmt.Errorf("error populating status parent author: %w", err) } } - - return &status, nil - }, keyParts...) - if err != nil { - // error already processed - return nil, err } - // Set the status author account - status.Account, err = s.state.DB.GetAccountByID(ctx, status.AccountID) - if err != nil { - return nil, fmt.Errorf("error getting status account: %w", err) - } + if status.BoostOfID != "" { + if status.BoostOf == nil { + // Status boost is not set, fetch from database. + status.BoostOf, err = s.GetStatusByID( + gtscontext.SetBarebones(ctx), + status.BoostOfID, + ) + if err != nil { + return fmt.Errorf("error populating status boost: %w", err) + } + } - if id := status.BoostOfAccountID; id != "" { - // Set boost of status' author account - status.BoostOfAccount, err = s.state.DB.GetAccountByID(ctx, id) - if err != nil { - return nil, fmt.Errorf("error getting boosted status account: %w", err) + if status.BoostOfAccount == nil { + // Status boost author is not set, fetch from database. + status.BoostOfAccount, err = s.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + status.BoostOfAccountID, + ) + if err != nil { + return fmt.Errorf("error populating status boost author: %w", err) + } } } - if id := status.InReplyToAccountID; id != "" { - // Set in-reply-to status' author account - status.InReplyToAccount, err = s.state.DB.GetAccountByID(ctx, id) + if !status.AttachmentsPopulated() { + // Status attachments are out-of-date with IDs, repopulate. + status.Attachments, err = s.state.DB.GetAttachmentsByIDs( + ctx, // these are already barebones + status.AttachmentIDs, + ) if err != nil { - return nil, fmt.Errorf("error getting in reply to status account: %w", err) + return fmt.Errorf("error populating status attachments: %w", err) } } - if len(status.EmojiIDs) > 0 { - // Fetch status emojis - status.Emojis, err = s.state.DB.GetEmojisByIDs(ctx, status.EmojiIDs) + // TODO: once we don't fetch using relations. + // if !status.TagsPopulated() { + // } + + if !status.MentionsPopulated() { + // Status mentions are out-of-date with IDs, repopulate. + status.Mentions, err = s.state.DB.GetMentions( + ctx, // leave fully populated for now + status.MentionIDs, + ) if err != nil { - return nil, fmt.Errorf("error getting status emojis: %w", err) + return fmt.Errorf("error populating status mentions: %w", err) } } - if len(status.MentionIDs) > 0 { - // Fetch status mentions - status.Mentions, err = s.state.DB.GetMentions(ctx, status.MentionIDs) + if !status.EmojisPopulated() { + // Status emojis are out-of-date with IDs, repopulate. + status.Emojis, err = s.state.DB.GetEmojisByIDs( + ctx, // these are already barebones + status.EmojiIDs, + ) if err != nil { - return nil, fmt.Errorf("error getting status mentions: %w", err) + return fmt.Errorf("error populating status emojis: %w", err) } } - return status, nil + return nil } func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error { @@ -239,12 +301,16 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er }) }) if err != nil { - // already processed return err } for _, id := range status.AttachmentIDs { - // Clear updated media attachment IDs from cache + // Invalidate media attachments from cache. + // + // NOTE: this is needed due to the way in which + // we upload status attachments, and only after + // update them with a known status ID. This is + // not the case for header/avatar attachments. s.state.Caches.GTS.Media().Invalidate("ID", id) } @@ -322,14 +388,19 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co return err } + // Invalidate status from database lookups. + s.state.Caches.GTS.Status().Invalidate("ID", status.ID) + for _, id := range status.AttachmentIDs { - // Clear updated media attachment IDs from cache + // Invalidate media attachments from cache. + // + // NOTE: this is needed due to the way in which + // we upload status attachments, and only after + // update them with a known status ID. This is + // not the case for header/avatar attachments. s.state.Caches.GTS.Media().Invalidate("ID", id) } - // Drop any old status value from cache by this ID - s.state.Caches.GTS.Status().Invalidate("ID", status.ID) - return nil } @@ -367,8 +438,12 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { return err } - // Drop any old value from cache by this ID + // Invalidate status from database lookups. s.state.Caches.GTS.Status().Invalidate("ID", id) + + // Invalidate status from all visibility lookups. + s.state.Caches.Visibility.Invalidate("ItemID", id) + return nil } diff --git a/internal/db/bundb/statusfave.go b/internal/db/bundb/statusfave.go index c42ab249f..0f7e5df74 100644 --- a/internal/db/bundb/statusfave.go +++ b/internal/db/bundb/statusfave.go @@ -23,6 +23,7 @@ import ( "fmt" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" @@ -34,29 +35,82 @@ type statusFaveDB struct { state *state.State } -func (s *statusFaveDB) GetStatusFave(ctx context.Context, id string) (*gtsmodel.StatusFave, db.Error) { - fave := new(gtsmodel.StatusFave) +func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, db.Error) { + return s.getStatusFave( + ctx, + "AccountID.StatusID", + func(fave *gtsmodel.StatusFave) error { + return s.conn. + NewSelect(). + Model(fave). + Where("? = ?", bun.Ident("account_id"), accountID). + Where("? = ?", bun.Ident("status_id"), statusID). + Scan(ctx) + }, + accountID, + statusID, + ) +} - err := s.conn. - NewSelect(). - Model(fave). - Where("? = ?", bun.Ident("status_fave.ID"), id). - Scan(ctx) +func (s *statusFaveDB) GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, db.Error) { + return s.getStatusFave( + ctx, + "ID", + func(fave *gtsmodel.StatusFave) error { + return s.conn. + NewSelect(). + Model(fave). + Where("? = ?", bun.Ident("id"), id). + Scan(ctx) + }, + id, + ) +} + +func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery func(*gtsmodel.StatusFave) error, keyParts ...any) (*gtsmodel.StatusFave, error) { + // Fetch status fave from database cache with loader callback + fave, err := s.state.Caches.GTS.StatusFave().Load(lookup, func() (*gtsmodel.StatusFave, error) { + var fave gtsmodel.StatusFave + + // Not cached! Perform database query. + if err := dbQuery(&fave); err != nil { + return nil, s.conn.ProcessError(err) + } + + return &fave, nil + }, keyParts...) if err != nil { - return nil, s.conn.ProcessError(err) + return nil, err + } + + if gtscontext.Barebones(ctx) { + // no need to fully populate. + return fave, nil } - fave.Account, err = s.state.DB.GetAccountByID(ctx, fave.AccountID) + // Fetch the status fave author account. + fave.Account, err = s.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + fave.AccountID, + ) if err != nil { return nil, fmt.Errorf("error getting status fave account %q: %w", fave.AccountID, err) } - fave.TargetAccount, err = s.state.DB.GetAccountByID(ctx, fave.TargetAccountID) + // Fetch the status fave target account. + fave.TargetAccount, err = s.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + fave.TargetAccountID, + ) if err != nil { return nil, fmt.Errorf("error getting status fave target account %q: %w", fave.TargetAccountID, err) } - fave.Status, err = s.state.DB.GetStatusByID(ctx, fave.StatusID) + // Fetch the status fave target status. + fave.Status, err = s.state.DB.GetStatusByID( + gtscontext.SetBarebones(ctx), + fave.StatusID, + ) if err != nil { return nil, fmt.Errorf("error getting status fave status %q: %w", fave.StatusID, err) } @@ -64,38 +118,22 @@ func (s *statusFaveDB) GetStatusFave(ctx context.Context, id string) (*gtsmodel. return fave, nil } -func (s *statusFaveDB) GetStatusFaveByAccountID(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, db.Error) { - var id string - - err := s.conn. - NewSelect(). - TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")). - Column("status_fave.id"). - Where("? = ?", bun.Ident("status_fave.account_id"), accountID). - Where("? = ?", bun.Ident("status_fave.status_id"), statusID). - Scan(ctx, &id) - if err != nil { - return nil, s.conn.ProcessError(err) - } - - return s.GetStatusFave(ctx, id) -} - -func (s *statusFaveDB) GetStatusFaves(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, db.Error) { +func (s *statusFaveDB) GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, db.Error) { ids := []string{} if err := s.conn. NewSelect(). - TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")). - Column("status_fave.id"). - Where("? = ?", bun.Ident("status_fave.status_id"), statusID). + Table("status_faves"). + Column("id"). + Where("? = ?", bun.Ident("status_id"), statusID). Scan(ctx, &ids); err != nil { return nil, s.conn.ProcessError(err) } faves := make([]*gtsmodel.StatusFave, 0, len(ids)) + for _, id := range ids { - fave, err := s.GetStatusFave(ctx, id) + fave, err := s.GetStatusFaveByID(ctx, id) if err != nil { log.Errorf(ctx, "error getting status fave %q: %v", id, err) continue @@ -107,23 +145,27 @@ func (s *statusFaveDB) GetStatusFaves(ctx context.Context, statusID string) ([]* return faves, nil } -func (s *statusFaveDB) PutStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) db.Error { - _, err := s.conn. - NewInsert(). - Model(statusFave). - Exec(ctx) - - return s.conn.ProcessError(err) +func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusFave) db.Error { + return s.state.Caches.GTS.StatusFave().Store(fave, func() error { + _, err := s.conn. + NewInsert(). + Model(fave). + Exec(ctx) + return s.conn.ProcessError(err) + }) } -func (s *statusFaveDB) DeleteStatusFave(ctx context.Context, id string) db.Error { - _, err := s.conn. +func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) db.Error { + if _, err := s.conn. NewDelete(). - TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")). - Where("? = ?", bun.Ident("status_fave.id"), id). - Exec(ctx) + Table("status_faves"). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx); err != nil { + return s.conn.ProcessError(err) + } - return s.conn.ProcessError(err) + s.state.Caches.GTS.StatusFave().Invalidate("ID", id) + return nil } func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) db.Error { @@ -131,42 +173,52 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st return errors.New("DeleteStatusFaves: one of targetAccountID or originAccountID must be set") } - // TODO: Capture fave IDs in a RETURNING - // statement (when faves have a cache), - // + use the IDs to invalidate cache entries. + // Capture fave IDs in a RETURNING statement. + var faveIDs []string q := s.conn. NewDelete(). - TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")) + Table("status_faves"). + Returning("?", bun.Ident("id")) if targetAccountID != "" { - q = q.Where("? = ?", bun.Ident("status_fave.target_account_id"), targetAccountID) + q = q.Where("? = ?", bun.Ident("target_account_id"), targetAccountID) } if originAccountID != "" { - q = q.Where("? = ?", bun.Ident("status_fave.account_id"), originAccountID) + q = q.Where("? = ?", bun.Ident("account_id"), originAccountID) } - if _, err := q.Exec(ctx); err != nil { + if _, err := q.Exec(ctx, &faveIDs); err != nil { return s.conn.ProcessError(err) } + for _, id := range faveIDs { + // Invalidate each of the returned status fave IDs. + s.state.Caches.GTS.StatusFave().Invalidate("ID", id) + } + return nil } func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID string) db.Error { - // TODO: Capture fave IDs in a RETURNING - // statement (when faves have a cache), - // + use the IDs to invalidate cache entries. + // Capture fave IDs in a RETURNING statement. + var faveIDs []string q := s.conn. NewDelete(). - TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")). - Where("? = ?", bun.Ident("status_fave.status_id"), statusID) + Table("status_faves"). + Where("? = ?", bun.Ident("status_id"), statusID). + Returning("?", bun.Ident("id")) - if _, err := q.Exec(ctx); err != nil { + if _, err := q.Exec(ctx, &faveIDs); err != nil { return s.conn.ProcessError(err) } + for _, id := range faveIDs { + // Invalidate each of the returned status fave IDs. + s.state.Caches.GTS.StatusFave().Invalidate("ID", id) + } + return nil } diff --git a/internal/db/bundb/statusfave_test.go b/internal/db/bundb/statusfave_test.go index 98e495bf3..7218390bc 100644 --- a/internal/db/bundb/statusfave_test.go +++ b/internal/db/bundb/statusfave_test.go @@ -35,7 +35,7 @@ type StatusFaveTestSuite struct { func (suite *StatusFaveTestSuite) TestGetStatusFaves() { testStatus := suite.testStatuses["admin_account_status_1"] - faves, err := suite.db.GetStatusFaves(context.Background(), testStatus.ID) + faves, err := suite.db.GetStatusFavesForStatus(context.Background(), testStatus.ID) if err != nil { suite.FailNow(err.Error()) } @@ -51,7 +51,7 @@ func (suite *StatusFaveTestSuite) TestGetStatusFaves() { func (suite *StatusFaveTestSuite) TestGetStatusFavesNone() { testStatus := suite.testStatuses["admin_account_status_4"] - faves, err := suite.db.GetStatusFaves(context.Background(), testStatus.ID) + faves, err := suite.db.GetStatusFavesForStatus(context.Background(), testStatus.ID) if err != nil { suite.FailNow(err.Error()) } @@ -63,7 +63,7 @@ func (suite *StatusFaveTestSuite) TestGetStatusFaveByAccountID() { testAccount := suite.testAccounts["local_account_1"] testStatus := suite.testStatuses["admin_account_status_1"] - fave, err := suite.db.GetStatusFaveByAccountID(context.Background(), testAccount.ID, testStatus.ID) + fave, err := suite.db.GetStatusFave(context.Background(), testAccount.ID, testStatus.ID) suite.NoError(err) suite.NotNil(fave) } @@ -129,17 +129,17 @@ func (suite *StatusFaveTestSuite) TestDeleteStatusFave() { testFave := suite.testFaves["local_account_1_admin_account_status_1"] ctx := context.Background() - if err := suite.db.DeleteStatusFave(ctx, testFave.ID); err != nil { + if err := suite.db.DeleteStatusFaveByID(ctx, testFave.ID); err != nil { suite.FailNow(err.Error()) } - fave, err := suite.db.GetStatusFave(ctx, testFave.ID) + fave, err := suite.db.GetStatusFaveByID(ctx, testFave.ID) suite.ErrorIs(err, db.ErrNoEntries) suite.Nil(fave) } func (suite *StatusFaveTestSuite) TestDeleteStatusFaveNonExisting() { - err := suite.db.DeleteStatusFave(context.Background(), "01GVAV715K6Y2SG9ZKS9ZA8G7G") + err := suite.db.DeleteStatusFaveByID(context.Background(), "01GVAV715K6Y2SG9ZKS9ZA8G7G") suite.NoError(err) } diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go index ea4a87d03..1ab140103 100644 --- a/internal/db/bundb/timeline.go +++ b/internal/db/bundb/timeline.go @@ -61,9 +61,12 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI Order("status.id DESC") if maxID == "" { + const future = 24 * time.Hour + var err error - // don't return statuses more than five minutes in the future - maxID, err = id.NewULIDFromTime(time.Now().Add(5 * time.Minute)) + + // don't return statuses more than 24hr in the future + maxID, err = id.NewULIDFromTime(time.Now().Add(future)) if err != nil { return nil, err } @@ -138,15 +141,16 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceI TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). Column("status.id"). Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic). - WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_id")). - WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")). WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")). Order("status.id DESC") if maxID == "" { + const future = 24 * time.Hour + var err error - // don't return statuses more than five minutes in the future - maxID, err = id.NewULIDFromTime(time.Now().Add(5 * time.Minute)) + + // don't return statuses more than 24hr in the future + maxID, err = id.NewULIDFromTime(time.Now().Add(future)) if err != nil { return nil, err } diff --git a/internal/db/bundb/timeline_test.go b/internal/db/bundb/timeline_test.go index 5a447111c..d6632b38c 100644 --- a/internal/db/bundb/timeline_test.go +++ b/internal/db/bundb/timeline_test.go @@ -34,15 +34,32 @@ type TimelineTestSuite struct { } func (suite *TimelineTestSuite) TestGetPublicTimeline() { - ctx := context.Background() + var count int + + for _, status := range suite.testStatuses { + if status.Visibility == gtsmodel.VisibilityPublic && + status.BoostOfID == "" { + count++ + } + } + ctx := context.Background() s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false) suite.NoError(err) - suite.Len(s, 6) + suite.Len(s, count) } func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() { + var count int + + for _, status := range suite.testStatuses { + if status.Visibility == gtsmodel.VisibilityPublic && + status.BoostOfID == "" { + count++ + } + } + ctx := context.Background() futureStatus := getFutureStatus() @@ -53,7 +70,7 @@ func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() { suite.NoError(err) suite.NotContains(s, futureStatus) - suite.Len(s, 6) + suite.Len(s, count) } func (suite *TimelineTestSuite) TestGetHomeTimeline() { diff --git a/internal/db/media.go b/internal/db/media.go index d86f9fe84..05609ba52 100644 --- a/internal/db/media.go +++ b/internal/db/media.go @@ -29,6 +29,9 @@ type Media interface { // GetAttachmentByID gets a single attachment by its ID. GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, Error) + // GetAttachmentsByIDs fetches a list of media attachments for given IDs. + GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error) + // PutAttachment inserts the given attachment into the database. PutAttachment(ctx context.Context, media *gtsmodel.MediaAttachment) error diff --git a/internal/db/mention.go b/internal/db/mention.go index d66394a5d..348f946a2 100644 --- a/internal/db/mention.go +++ b/internal/db/mention.go @@ -30,4 +30,10 @@ type Mention interface { // GetMentions gets multiple mentions. GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, Error) + + // PutMention will insert the given mention into the database. + PutMention(ctx context.Context, mention *gtsmodel.Mention) error + + // DeleteMentionByID will delete mention with given ID from the database. + DeleteMentionByID(ctx context.Context, id string) error } diff --git a/internal/db/notification.go b/internal/db/notification.go index 18e40b4c1..fd3affe90 100644 --- a/internal/db/notification.go +++ b/internal/db/notification.go @@ -28,14 +28,17 @@ type Notification interface { // GetNotifications returns a slice of notifications that pertain to the given accountID. // // Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest). - GetNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error) + GetAccountNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error) // GetNotification returns one notification according to its id. - GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, Error) + GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, Error) - // DeleteNotification deletes one notification according to its id, + // PutNotification will insert the given notification into the database. + PutNotification(ctx context.Context, notif *gtsmodel.Notification) error + + // DeleteNotificationByID deletes one notification according to its id, // and removes that notification from the in-memory cache. - DeleteNotification(ctx context.Context, id string) Error + DeleteNotificationByID(ctx context.Context, id string) Error // DeleteNotifications mass deletes notifications targeting targetAccountID // and/or originating from originAccountID. @@ -50,7 +53,7 @@ type Notification interface { // originate from originAccountID will be deleted. // // At least one parameter must not be an empty string. - DeleteNotifications(ctx context.Context, targetAccountID string, originAccountID string) Error + DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) Error // DeleteNotificationsForStatus deletes all notifications that relate to // the given statusID. This function is useful when a status has been deleted, diff --git a/internal/db/relationship.go b/internal/db/relationship.go index d13a73dea..838647154 100644 --- a/internal/db/relationship.go +++ b/internal/db/relationship.go @@ -25,42 +25,86 @@ import ( // Relationship contains functions for getting or modifying the relationship between two accounts. type Relationship interface { - // IsBlocked checks whether account 1 has a block in place against account2. - // If eitherDirection is true, then the function returns true if account1 blocks account2, OR if account2 blocks account1. - IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, Error) + // IsBlocked checks whether source account has a block in place against target. + IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error) + + // IsEitherBlocked checks whether there is a block in place between either of account1 and account2. + IsEitherBlocked(ctx context.Context, accountID1 string, accountID2 string) (bool, error) + + // GetBlockByID fetches block with given ID from the database. + GetBlockByID(ctx context.Context, id string) (*gtsmodel.Block, error) + + // GetBlockByURI fetches block with given AP URI from the database. + GetBlockByURI(ctx context.Context, uri string) (*gtsmodel.Block, error) // GetBlock returns the block from account1 targeting account2, if it exists, or an error if it doesn't. - // - // Because this is slower than Blocked, only use it if you need the actual Block struct for some reason, - // not if you're just checking for the existence of a block. - GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, Error) + GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, error) // PutBlock attempts to place the given account block in the database. - PutBlock(ctx context.Context, block *gtsmodel.Block) Error + PutBlock(ctx context.Context, block *gtsmodel.Block) error // DeleteBlockByID removes block with given ID from the database. - DeleteBlockByID(ctx context.Context, id string) Error + DeleteBlockByID(ctx context.Context, id string) error // DeleteBlockByURI removes block with given AP URI from the database. - DeleteBlockByURI(ctx context.Context, uri string) Error - - // DeleteBlocksByOriginAccountID removes any blocks with accountID equal to originAccountID. - DeleteBlocksByOriginAccountID(ctx context.Context, originAccountID string) Error + DeleteBlockByURI(ctx context.Context, uri string) error - // DeleteBlocksByTargetAccountID removes any blocks with given targetAccountID. - DeleteBlocksByTargetAccountID(ctx context.Context, targetAccountID string) Error + // DeleteAccountBlocks will delete all database blocks to / from the given account ID. + DeleteAccountBlocks(ctx context.Context, accountID string) error // GetRelationship retrieves the relationship of the targetAccount to the requestingAccount. GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error) + // GetFollowByID fetches follow with given ID from the database. + GetFollowByID(ctx context.Context, id string) (*gtsmodel.Follow, error) + + // GetFollowByURI fetches follow with given AP URI from the database. + GetFollowByURI(ctx context.Context, uri string) (*gtsmodel.Follow, error) + + // GetFollow retrieves a follow if it exists between source and target accounts. + GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) + + // GetFollowRequestByID fetches follow request with given ID from the database. + GetFollowRequestByID(ctx context.Context, id string) (*gtsmodel.FollowRequest, error) + + // GetFollowRequestByURI fetches follow request with given AP URI from the database. + GetFollowRequestByURI(ctx context.Context, uri string) (*gtsmodel.FollowRequest, error) + + // GetFollowRequest retrieves a follow request if it exists between source and target accounts. + GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error) + // IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out. - IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error) + IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error) + + // IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out. + IsMutualFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error) // IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out. - IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error) + IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error) - // IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out. - IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error) + // PutFollow attempts to place the given account follow in the database. + PutFollow(ctx context.Context, follow *gtsmodel.Follow) error + + // PutFollowRequest attempts to place the given account follow request in the database. + PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error + + // DeleteFollowByID deletes a follow from the database with the given ID. + DeleteFollowByID(ctx context.Context, id string) error + + // DeleteFollowByURI deletes a follow from the database with the given URI. + DeleteFollowByURI(ctx context.Context, uri string) error + + // DeleteFollowRequestByID deletes a follow request from the database with the given ID. + DeleteFollowRequestByID(ctx context.Context, id string) error + + // DeleteFollowRequestByURI deletes a follow request from the database with the given URI. + DeleteFollowRequestByURI(ctx context.Context, uri string) error + + // DeleteAccountFollows will delete all database follows to / from the given account ID. + DeleteAccountFollows(ctx context.Context, accountID string) error + + // DeleteAccountFollowRequests will delete all database follow requests to / from the given account ID. + DeleteAccountFollowRequests(ctx context.Context, accountID string) error // AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table. // In other words, it should create the follow, and delete the existing follow request. @@ -69,65 +113,41 @@ type Relationship interface { AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error) // RejectFollowRequest fetches a follow request from the database, and then deletes it. - // - // The deleted follow request will be returned so that further processing can be done on it. - RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, Error) + RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) Error - // GetFollows returns a slice of follows owned by the given accountID, and/or - // targeting the given account id. - // - // If accountID is set and targetAccountID isn't, then all follows created by - // accountID will be returned. - // - // If targetAccountID is set and accountID isn't, then all follows targeting - // targetAccountID will be returned. - // - // If both accountID and targetAccountID are set, then only 0 or 1 follows will - // be in the returned slice. - GetFollows(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.Follow, Error) + // GetAccountFollows returns a slice of follows owned by the given accountID. + GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) - // GetLocalFollowersIDs returns a list of local account IDs which follow the - // targetAccountID. The returned IDs are not guaranteed to be ordered in any - // particular way, so take care. - GetLocalFollowersIDs(ctx context.Context, targetAccountID string) ([]string, Error) + // GetAccountLocalFollows returns a slice of follows owned by the given accountID, only including follows from this instance. + GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) - // CountFollows is like GetFollows, but just counts rather than returning. - CountFollows(ctx context.Context, accountID string, targetAccountID string) (int, Error) + // CountAccountFollows returns the amount of accounts that the given accountID is following. + CountAccountFollows(ctx context.Context, accountID string) (int, error) - // GetFollowRequests returns a slice of follows requests owned by the given - // accountID, and/or targeting the given account id. - // - // If accountID is set and targetAccountID isn't, then all requests created by - // accountID will be returned. - // - // If targetAccountID is set and accountID isn't, then all requests targeting - // targetAccountID will be returned. - // - // If both accountID and targetAccountID are set, then only 0 or 1 requests will - // be in the returned slice. - GetFollowRequests(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.FollowRequest, Error) + // CountAccountLocalFollows returns the amount of accounts that the given accountID is following, only including follows from this instance. + CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) - // CountFollowRequests is like GetFollowRequests, but just counts rather than returning. - CountFollowRequests(ctx context.Context, accountID string, targetAccountID string) (int, Error) + // GetAccountFollowers fetches follows that target given accountID. + GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) - // Unfollow removes a follow targeting targetAccountID and originating - // from originAccountID. - // - // If a follow was removed this way, the AP URI of the follow will be - // returned to the caller, so that further processing can take place - // if necessary. - // - // If no follow was removed this way, the returned string will be empty. - Unfollow(ctx context.Context, originAccountID string, targetAccountID string) (string, Error) + // GetAccountLocalFollowers fetches follows that target given accountID, only including follows from this instance. + GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) - // UnfollowRequest removes a follow request targeting targetAccountID - // and originating from originAccountID. - // - // If a follow request was removed this way, the AP URI of the follow - // request will be returned to the caller, so that further processing - // can take place if necessary. - // - // If no follow request was removed this way, the returned string will - // be empty. - UnfollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (string, Error) + // CountAccountFollowers returns the amounts that the given ID is followed by. + CountAccountFollowers(ctx context.Context, accountID string) (int, error) + + // CountAccountLocalFollowers returns the amounts that the given ID is followed by, only including follows from this instance. + CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) + + // GetAccountFollowRequests returns all follow requests targeting the given account. + GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) + + // GetAccountFollowRequesting returns all follow requests originating from the given account. + GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) + + // CountAccountFollowRequests returns number of follow requests targeting the given account. + CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) + + // CountAccountFollowerRequests returns number of follow requests originating from the given account. + CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) } diff --git a/internal/db/status.go b/internal/db/status.go index 16728983a..fdce19094 100644 --- a/internal/db/status.go +++ b/internal/db/status.go @@ -37,6 +37,9 @@ type Status interface { // GetStatusByURL returns one status from the database, with no rel fields populated, only their linking ID / URIs GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error) + // PopulateStatus ensures that all sub-models of a status are populated (e.g. mentions, attachments, etc). + PopulateStatus(ctx context.Context, status *gtsmodel.Status) error + // PutStatus stores one status in the database. PutStatus(ctx context.Context, status *gtsmodel.Status) Error diff --git a/internal/db/statusfave.go b/internal/db/statusfave.go index 2d55592aa..b435da514 100644 --- a/internal/db/statusfave.go +++ b/internal/db/statusfave.go @@ -24,22 +24,22 @@ import ( ) type StatusFave interface { - // GetStatusFave returns one status fave with the given id. - GetStatusFave(ctx context.Context, id string) (*gtsmodel.StatusFave, Error) - // GetStatusFaveByAccountID gets one status fave created by the given // accountID, targeting the given statusID. - GetStatusFaveByAccountID(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, Error) + GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, Error) + + // GetStatusFave returns one status fave with the given id. + GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, Error) // GetStatusFaves returns a slice of faves/likes of the given status. // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user. - GetStatusFaves(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, Error) + GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, Error) // PutStatusFave inserts the given statusFave into the database. PutStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) Error // DeleteStatusFave deletes one status fave with the given id. - DeleteStatusFave(ctx context.Context, id string) Error + DeleteStatusFaveByID(ctx context.Context, id string) Error // DeleteStatusFaves mass deletes status faves targeting targetAccountID // and/or originating from originAccountID and/or faving statusID. |