summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/account.go15
-rw-r--r--internal/db/bundb/account.go147
-rw-r--r--internal/db/bundb/account_test.go179
-rw-r--r--internal/db/bundb/basic_test.go11
-rw-r--r--internal/db/bundb/media.go42
-rw-r--r--internal/db/bundb/mention.go71
-rw-r--r--internal/db/bundb/migrations/20230328105630_chore_refactoring.go167
-rw-r--r--internal/db/bundb/notification.go25
-rw-r--r--internal/db/bundb/notification_test.go20
-rw-r--r--internal/db/bundb/relationship.go685
-rw-r--r--internal/db/bundb/relationship_block.go218
-rw-r--r--internal/db/bundb/relationship_follow.go243
-rw-r--r--internal/db/bundb/relationship_follow_req.go293
-rw-r--r--internal/db/bundb/relationship_test.go493
-rw-r--r--internal/db/bundb/status.go179
-rw-r--r--internal/db/bundb/statusfave.go170
-rw-r--r--internal/db/bundb/statusfave_test.go12
-rw-r--r--internal/db/bundb/timeline.go16
-rw-r--r--internal/db/bundb/timeline_test.go23
-rw-r--r--internal/db/media.go3
-rw-r--r--internal/db/mention.go6
-rw-r--r--internal/db/notification.go13
-rw-r--r--internal/db/relationship.go166
-rw-r--r--internal/db/status.go3
-rw-r--r--internal/db/statusfave.go12
25 files changed, 2261 insertions, 951 deletions
diff --git a/internal/db/account.go b/internal/db/account.go
index 6ecfea018..4a08918b0 100644
--- a/internal/db/account.go
+++ b/internal/db/account.go
@@ -41,6 +41,21 @@ type Account interface {
// GetAccountByPubkeyID returns one account with the given public key URI (ID), or an error if something goes wrong.
GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, Error)
+ // GetAccountByInboxURI returns one account with the given inbox_uri, or an error if something goes wrong.
+ GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
+
+ // GetAccountByOutboxURI returns one account with the given outbox_uri, or an error if something goes wrong.
+ GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
+
+ // GetAccountByFollowingURI returns one account with the given following_uri, or an error if something goes wrong.
+ GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
+
+ // GetAccountByFollowersURI returns one account with the given followers_uri, or an error if something goes wrong.
+ GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
+
+ // PopulateAccount ensures that all sub-models of an account are populated (e.g. avatar, header etc).
+ PopulateAccount(ctx context.Context, account *gtsmodel.Account) error
+
// PutAccount puts one account in the database.
PutAccount(ctx context.Context, account *gtsmodel.Account) Error
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index df73168e2..ccf7aaa46 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -20,11 +20,13 @@ package bundb
import (
"context"
"errors"
+ "fmt"
"strings"
"time"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
@@ -37,18 +39,15 @@ type accountDB struct {
state *state.State
}
-func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
- return a.conn.
- NewSelect().
- Model(account)
-}
-
func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
"ID",
func(account *gtsmodel.Account) error {
- return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx)
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.id"), id).
+ Scan(ctx)
},
id,
)
@@ -59,7 +58,10 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
ctx,
"URI",
func(account *gtsmodel.Account) error {
- return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx)
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.uri"), uri).
+ Scan(ctx)
},
uri,
)
@@ -70,7 +72,10 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.
ctx,
"URL",
func(account *gtsmodel.Account) error {
- return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx)
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.url"), url).
+ Scan(ctx)
},
url,
)
@@ -81,7 +86,8 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
ctx,
"Username.Domain",
func(account *gtsmodel.Account) error {
- q := a.newAccountQ(account)
+ q := a.conn.NewSelect().
+ Model(account)
if domain != "" {
q = q.
@@ -105,12 +111,71 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo
ctx,
"PublicKeyURI",
func(account *gtsmodel.Account) error {
- return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx)
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.public_key_uri"), id).
+ Scan(ctx)
},
id,
)
}
+func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+ return a.getAccount(
+ ctx,
+ "InboxURI",
+ func(account *gtsmodel.Account) error {
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.inbox_uri"), uri).
+ Scan(ctx)
+ },
+ uri,
+ )
+}
+
+func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+ return a.getAccount(
+ ctx,
+ "OutboxURI",
+ func(account *gtsmodel.Account) error {
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.outbox_uri"), uri).
+ Scan(ctx)
+ },
+ uri,
+ )
+}
+
+func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+ return a.getAccount(
+ ctx,
+ "FollowersURI",
+ func(account *gtsmodel.Account) error {
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.followers_uri"), uri).
+ Scan(ctx)
+ },
+ uri,
+ )
+}
+
+func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+ return a.getAccount(
+ ctx,
+ "FollowingURI",
+ func(account *gtsmodel.Account) error {
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.following_uri"), uri).
+ Scan(ctx)
+ },
+ uri,
+ )
+}
+
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
var username string
@@ -141,31 +206,56 @@ func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(
return nil, err
}
- if account.AvatarMediaAttachmentID != "" {
- // Set the account's related avatar
- account.AvatarMediaAttachment, err = a.state.DB.GetAttachmentByID(ctx, account.AvatarMediaAttachmentID)
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return account, nil
+ }
+
+ // Further populate the account fields where applicable.
+ if err := a.PopulateAccount(ctx, account); err != nil {
+ return nil, err
+ }
+
+ return account, nil
+}
+
+func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Account) error {
+ var err error
+
+ if account.AvatarMediaAttachment == nil && account.AvatarMediaAttachmentID != "" {
+ // Account avatar attachment is not set, fetch from database.
+ account.AvatarMediaAttachment, err = a.state.DB.GetAttachmentByID(
+ ctx, // these are already barebones
+ account.AvatarMediaAttachmentID,
+ )
if err != nil {
- log.Errorf(ctx, "error getting account %s avatar: %v", account.ID, err)
+ return fmt.Errorf("error populating account avatar: %w", err)
}
}
- if account.HeaderMediaAttachmentID != "" {
- // Set the account's related header
- account.HeaderMediaAttachment, err = a.state.DB.GetAttachmentByID(ctx, account.HeaderMediaAttachmentID)
+ if account.HeaderMediaAttachment == nil && account.HeaderMediaAttachmentID != "" {
+ // Account header attachment is not set, fetch from database.
+ account.HeaderMediaAttachment, err = a.state.DB.GetAttachmentByID(
+ ctx, // these are already barebones
+ account.HeaderMediaAttachmentID,
+ )
if err != nil {
- log.Errorf(ctx, "error getting account %s header: %v", account.ID, err)
+ return fmt.Errorf("error populating account header: %w", err)
}
}
- if len(account.EmojiIDs) > 0 {
- // Set the account's related emojis
- account.Emojis, err = a.state.DB.GetEmojisByIDs(ctx, account.EmojiIDs)
+ if !account.EmojisPopulated() {
+ // Account emojis are out-of-date with IDs, repopulate.
+ account.Emojis, err = a.state.DB.GetEmojisByIDs(
+ ctx, // these are already barebones
+ account.EmojiIDs,
+ )
if err != nil {
- log.Errorf(ctx, "error getting account %s emojis: %v", account.ID, err)
+ return fmt.Errorf("error populating account emojis: %w", err)
}
}
- return account, nil
+ return nil
}
func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) db.Error {
@@ -198,7 +288,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
columns = append(columns, "updated_at")
}
- return a.state.Caches.GTS.Account().Store(account, func() error {
+ err := a.state.Caches.GTS.Account().Store(account, func() error {
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
@@ -234,6 +324,11 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
return err
})
})
+ if err != nil {
+ return err
+ }
+
+ return nil
}
func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
@@ -258,7 +353,9 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
return err
}
+ // Invalidate account from database lookups.
a.state.Caches.GTS.Account().Invalidate("ID", id)
+
return nil
}
diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go
index b7e8aaadc..2241ab783 100644
--- a/internal/db/bundb/account_test.go
+++ b/internal/db/bundb/account_test.go
@@ -21,6 +21,8 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
+ "errors"
+ "reflect"
"strings"
"testing"
"time"
@@ -61,44 +63,149 @@ func (suite *AccountTestSuite) TestGetAccountStatusesMediaOnly() {
suite.Len(statuses, 1)
}
-func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
- account, err := suite.db.GetAccountByID(context.Background(), suite.testAccounts["local_account_1"].ID)
- if err != nil {
- suite.FailNow(err.Error())
+func (suite *AccountTestSuite) TestGetAccountBy() {
+ t := suite.T()
+
+ // Create a new context for this test.
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ // Sentinel error to mark avoiding a test case.
+ sentinelErr := errors.New("sentinel")
+
+ // isEqual checks if 2 account models are equal.
+ isEqual := func(a1, a2 gtsmodel.Account) bool {
+ // Clear populated sub-models.
+ a1.HeaderMediaAttachment = nil
+ a2.HeaderMediaAttachment = nil
+ a1.AvatarMediaAttachment = nil
+ a2.AvatarMediaAttachment = nil
+ a1.Emojis = nil
+ a2.Emojis = nil
+
+ // Clear database-set fields.
+ a1.CreatedAt = time.Time{}
+ a2.CreatedAt = time.Time{}
+ a1.UpdatedAt = time.Time{}
+ a2.UpdatedAt = time.Time{}
+
+ // Manually compare keys.
+ pk1 := a1.PublicKey
+ pv1 := a1.PrivateKey
+ pk2 := a2.PublicKey
+ pv2 := a2.PrivateKey
+ a1.PublicKey = nil
+ a1.PrivateKey = nil
+ a2.PublicKey = nil
+ a2.PrivateKey = nil
+
+ return reflect.DeepEqual(a1, a2) &&
+ ((pk1 == nil && pk2 == nil) || pk1.Equal(pk2)) &&
+ ((pv1 == nil && pv2 == nil) || pv1.Equal(pv2))
}
- suite.NotNil(account)
- suite.NotNil(account.AvatarMediaAttachment)
- suite.NotEmpty(account.AvatarMediaAttachment.URL)
- suite.NotNil(account.HeaderMediaAttachment)
- suite.NotEmpty(account.HeaderMediaAttachment.URL)
-}
-
-func (suite *AccountTestSuite) TestGetAccountByUsernameDomain() {
- testAccount1 := suite.testAccounts["local_account_1"]
- account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount1.Username, testAccount1.Domain)
- suite.NoError(err)
- suite.NotNil(account1)
-
- testAccount2 := suite.testAccounts["remote_account_1"]
- account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount2.Username, testAccount2.Domain)
- suite.NoError(err)
- suite.NotNil(account2)
-}
-
-func (suite *AccountTestSuite) TestGetAccountByUsernameDomainMixedCase() {
- testAccount := suite.testAccounts["remote_account_2"]
- account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount.Username, testAccount.Domain)
- suite.NoError(err)
- suite.NotNil(account1)
-
- account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), strings.ToUpper(testAccount.Username), testAccount.Domain)
- suite.NoError(err)
- suite.NotNil(account2)
-
- account3, err := suite.db.GetAccountByUsernameDomain(context.Background(), strings.ToLower(testAccount.Username), testAccount.Domain)
- suite.NoError(err)
- suite.NotNil(account3)
+ for _, account := range suite.testAccounts {
+ for lookup, dbfunc := range map[string]func() (*gtsmodel.Account, error){
+ "id": func() (*gtsmodel.Account, error) {
+ return suite.db.GetAccountByID(ctx, account.ID)
+ },
+
+ "uri": func() (*gtsmodel.Account, error) {
+ return suite.db.GetAccountByURI(ctx, account.URI)
+ },
+
+ "url": func() (*gtsmodel.Account, error) {
+ if account.URL == "" {
+ return nil, sentinelErr
+ }
+ return suite.db.GetAccountByURL(ctx, account.URL)
+ },
+
+ "username@domain": func() (*gtsmodel.Account, error) {
+ return suite.db.GetAccountByUsernameDomain(ctx, account.Username, account.Domain)
+ },
+
+ "username_upper@domain": func() (*gtsmodel.Account, error) {
+ return suite.db.GetAccountByUsernameDomain(ctx, strings.ToUpper(account.Username), account.Domain)
+ },
+
+ "username_lower@domain": func() (*gtsmodel.Account, error) {
+ return suite.db.GetAccountByUsernameDomain(ctx, strings.ToLower(account.Username), account.Domain)
+ },
+
+ "public_key_uri": func() (*gtsmodel.Account, error) {
+ if account.PublicKeyURI == "" {
+ return nil, sentinelErr
+ }
+ return suite.db.GetAccountByPubkeyID(ctx, account.PublicKeyURI)
+ },
+
+ "inbox_uri": func() (*gtsmodel.Account, error) {
+ if account.InboxURI == "" {
+ return nil, sentinelErr
+ }
+ return suite.db.GetAccountByInboxURI(ctx, account.InboxURI)
+ },
+
+ "outbox_uri": func() (*gtsmodel.Account, error) {
+ if account.OutboxURI == "" {
+ return nil, sentinelErr
+ }
+ return suite.db.GetAccountByOutboxURI(ctx, account.OutboxURI)
+ },
+
+ "following_uri": func() (*gtsmodel.Account, error) {
+ if account.FollowingURI == "" {
+ return nil, sentinelErr
+ }
+ return suite.db.GetAccountByFollowingURI(ctx, account.FollowingURI)
+ },
+
+ "followers_uri": func() (*gtsmodel.Account, error) {
+ if account.FollowersURI == "" {
+ return nil, sentinelErr
+ }
+ return suite.db.GetAccountByFollowersURI(ctx, account.FollowersURI)
+ },
+ } {
+
+ // Clear database caches.
+ suite.state.Caches.Init()
+
+ t.Logf("checking database lookup %q", lookup)
+
+ // Perform database function.
+ checkAcc, err := dbfunc()
+ if err != nil {
+ if err == sentinelErr {
+ continue
+ }
+
+ t.Errorf("error encountered for database lookup %q: %v", lookup, err)
+ continue
+ }
+
+ // Check received account data.
+ if !isEqual(*checkAcc, *account) {
+ t.Errorf("account does not contain expected data: %+v", checkAcc)
+ continue
+ }
+
+ // Check that avatar attachment populated.
+ if account.AvatarMediaAttachmentID != "" &&
+ (checkAcc.AvatarMediaAttachment == nil || checkAcc.AvatarMediaAttachment.ID != account.AvatarMediaAttachmentID) {
+ t.Errorf("account avatar media attachment not correctly populated for: %+v", account)
+ continue
+ }
+
+ // Check that header attachment populated.
+ if account.HeaderMediaAttachmentID != "" &&
+ (checkAcc.HeaderMediaAttachment == nil || checkAcc.HeaderMediaAttachment.ID != account.HeaderMediaAttachmentID) {
+ t.Errorf("account header media attachment not correctly populated for: %+v", account)
+ continue
+ }
+ }
+ }
}
func (suite *AccountTestSuite) TestUpdateAccount() {
diff --git a/internal/db/bundb/basic_test.go b/internal/db/bundb/basic_test.go
index f238f2273..a24deac9e 100644
--- a/internal/db/bundb/basic_test.go
+++ b/internal/db/bundb/basic_test.go
@@ -19,6 +19,8 @@ package bundb_test
import (
"context"
+ "crypto/rand"
+ "crypto/rsa"
"testing"
"time"
@@ -40,6 +42,12 @@ func (suite *BasicTestSuite) TestGetAccountByID() {
}
func (suite *BasicTestSuite) TestPutAccountWithBunDefaultFields() {
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // Create an account that only just matches constraints.
testAccount := &gtsmodel.Account{
ID: "01GADR1AH9VCKH8YYCM86XSZ00",
Username: "test",
@@ -49,6 +57,7 @@ func (suite *BasicTestSuite) TestPutAccountWithBunDefaultFields() {
OutboxURI: "https://example.org/users/test/outbox",
ActorType: "Person",
PublicKeyURI: "https://example.org/test#main-key",
+ PublicKey: &key.PublicKey,
}
if err := suite.db.Put(context.Background(), testAccount); err != nil {
@@ -99,7 +108,7 @@ func (suite *BasicTestSuite) TestPutAccountWithBunDefaultFields() {
suite.Empty(a.FeaturedCollectionURI)
suite.Equal(testAccount.ActorType, a.ActorType)
suite.Nil(a.PrivateKey)
- suite.Nil(a.PublicKey)
+ suite.EqualValues(key.PublicKey, *a.PublicKey)
suite.Equal(testAccount.PublicKeyURI, a.PublicKeyURI)
suite.Zero(a.SensitizedAt)
suite.Zero(a.SilencedAt)
diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go
index 1a9d3be05..d17d64b35 100644
--- a/internal/db/bundb/media.go
+++ b/internal/db/bundb/media.go
@@ -47,6 +47,24 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M
)
}
+func (m *mediaDB) GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error) {
+ attachments := make([]*gtsmodel.MediaAttachment, 0, len(ids))
+
+ for _, id := range ids {
+ // Attempt fetch from DB
+ attachment, err := m.GetAttachmentByID(ctx, id)
+ if err != nil {
+ log.Errorf(ctx, "error getting attachment %q: %v", id, err)
+ continue
+ }
+
+ // Append attachment
+ attachments = append(attachments, attachment)
+ }
+
+ return attachments, nil
+}
+
func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func(*gtsmodel.MediaAttachment) error, keyParts ...any) (*gtsmodel.MediaAttachment, db.Error) {
return m.state.Caches.GTS.Media().Load(lookup, func() (*gtsmodel.MediaAttachment, error) {
var attachment gtsmodel.MediaAttachment
@@ -118,7 +136,7 @@ func (m *mediaDB) GetRemoteOlderThan(ctx context.Context, olderThan time.Time, l
return nil, m.conn.ProcessError(err)
}
- return m.getAttachments(ctx, attachmentIDs)
+ return m.GetAttachmentsByIDs(ctx, attachmentIDs)
}
func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time) (int, db.Error) {
@@ -163,7 +181,7 @@ func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit
return nil, m.conn.ProcessError(err)
}
- return m.getAttachments(ctx, attachmentIDs)
+ return m.GetAttachmentsByIDs(ctx, attachmentIDs)
}
func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, db.Error) {
@@ -189,7 +207,7 @@ func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan tim
return nil, m.conn.ProcessError(err)
}
- return m.getAttachments(ctx, attachmentIDs)
+ return m.GetAttachmentsByIDs(ctx, attachmentIDs)
}
func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time) (int, db.Error) {
@@ -211,21 +229,3 @@ func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan t
return count, nil
}
-
-func (m *mediaDB) getAttachments(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, db.Error) {
- attachments := make([]*gtsmodel.MediaAttachment, 0, len(ids))
-
- for _, id := range ids {
- // Attempt fetch from DB
- attachment, err := m.GetAttachmentByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting attachment %q: %v", id, err)
- continue
- }
-
- // Append attachment
- attachments = append(attachments, attachment)
- }
-
- return attachments, nil
-}
diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go
index 3a543f3c2..e64d6dac4 100644
--- a/internal/db/bundb/mention.go
+++ b/internal/db/bundb/mention.go
@@ -19,8 +19,10 @@ package bundb
import (
"context"
+ "fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
@@ -32,20 +34,13 @@ type mentionDB struct {
state *state.State
}
-func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {
- return m.conn.
- NewSelect().
- Model(i).
- Relation("Status").
- Relation("OriginAccount").
- Relation("TargetAccount")
-}
-
func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
- return m.state.Caches.GTS.Mention().Load("ID", func() (*gtsmodel.Mention, error) {
+ mention, err := m.state.Caches.GTS.Mention().Load("ID", func() (*gtsmodel.Mention, error) {
var mention gtsmodel.Mention
- q := m.newMentionQ(&mention).
+ q := m.conn.
+ NewSelect().
+ Model(&mention).
Where("? = ?", bun.Ident("mention.id"), id)
if err := q.Scan(ctx); err != nil {
@@ -54,6 +49,38 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio
return &mention, nil
}, id)
+ if err != nil {
+ return nil, err
+ }
+
+ // Set the mention originating status.
+ mention.Status, err = m.state.DB.GetStatusByID(
+ gtscontext.SetBarebones(ctx),
+ mention.StatusID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error populating mention status: %w", err)
+ }
+
+ // Set the mention origin account model.
+ mention.OriginAccount, err = m.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ mention.OriginAccountID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error populating mention origin account: %w", err)
+ }
+
+ // Set the mention target account model.
+ mention.TargetAccount, err = m.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ mention.TargetAccountID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error populating mention target account: %w", err)
+ }
+
+ return mention, nil
}
func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) {
@@ -73,3 +100,25 @@ func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.
return mentions, nil
}
+
+func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error {
+ return m.state.Caches.GTS.Mention().Store(mention, func() error {
+ _, err := m.conn.NewInsert().Model(mention).Exec(ctx)
+ return m.conn.ProcessError(err)
+ })
+}
+
+func (m *mentionDB) DeleteMentionByID(ctx context.Context, id string) error {
+ if _, err := m.conn.
+ NewDelete().
+ Table("mentions").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx); err != nil {
+ return m.conn.ProcessError(err)
+ }
+
+ // Invalidate mention from the lookup cache.
+ m.state.Caches.GTS.Mention().Invalidate("ID", id)
+
+ return nil
+}
diff --git a/internal/db/bundb/migrations/20230328105630_chore_refactoring.go b/internal/db/bundb/migrations/20230328105630_chore_refactoring.go
new file mode 100644
index 000000000..3bf9d59ef
--- /dev/null
+++ b/internal/db/bundb/migrations/20230328105630_chore_refactoring.go
@@ -0,0 +1,167 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package migrations
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+func init() {
+ up := func(ctx context.Context, db *bun.DB) error {
+ // To update unique constraint on public key, we need to migrate accounts into a new table.
+ // See section 7 here: https://www.sqlite.org/lang_altertable.html
+
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ // Create the new accounts table.
+ if _, err := tx.
+ NewCreateTable().
+ ModelTableExpr("new_accounts").
+ Model(&gtsmodel.Account{}).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // If we don't specify columns explicitly,
+ // Postgres gives the following error when
+ // transferring accounts to new_accounts:
+ //
+ // ERROR: column "fetched_at" is of type timestamp with time zone but expression is of type character varying at character 35
+ // HINT: You will need to rewrite or cast the expression.
+ //
+ // Rather than do funky casting to fix this,
+ // it's simpler to just specify all columns.
+ columns := []string{
+ "id",
+ "created_at",
+ "updated_at",
+ "fetched_at",
+ "username",
+ "domain",
+ "avatar_media_attachment_id",
+ "avatar_remote_url",
+ "header_media_attachment_id",
+ "header_remote_url",
+ "display_name",
+ "emojis",
+ "fields",
+ "note",
+ "note_raw",
+ "memorial",
+ "also_known_as",
+ "moved_to_account_id",
+ "bot",
+ "reason",
+ "locked",
+ "discoverable",
+ "privacy",
+ "sensitive",
+ "language",
+ "status_content_type",
+ "custom_css",
+ "uri",
+ "url",
+ "inbox_uri",
+ "shared_inbox_uri",
+ "outbox_uri",
+ "following_uri",
+ "followers_uri",
+ "featured_collection_uri",
+ "actor_type",
+ "private_key",
+ "public_key",
+ "public_key_uri",
+ "sensitized_at",
+ "silenced_at",
+ "suspended_at",
+ "hide_collections",
+ "suspension_origin",
+ "enable_rss",
+ }
+
+ // Copy all accounts to the new table.
+ if _, err := tx.
+ NewInsert().
+ Table("new_accounts").
+ Table("accounts").
+ Column(columns...).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // Drop the old table.
+ if _, err := tx.
+ NewDropTable().
+ Table("accounts").
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // Rename new table to old table.
+ if _, err := tx.
+ ExecContext(
+ ctx,
+ "ALTER TABLE ? RENAME TO ?",
+ bun.Ident("new_accounts"),
+ bun.Ident("accounts"),
+ ); err != nil {
+ return err
+ }
+
+ // Add all account indexes to the new table.
+ for index, columns := range map[string][]string{
+ // Standard indices.
+ "accounts_id_idx": {"id"},
+ "accounts_suspended_at_idx": {"suspended_at"},
+ "accounts_domain_idx": {"domain"},
+ "accounts_username_domain_idx": {"username", "domain"},
+ // URI indices.
+ "accounts_uri_idx": {"uri"},
+ "accounts_url_idx": {"url"},
+ "accounts_inbox_uri_idx": {"inbox_uri"},
+ "accounts_outbox_uri_idx": {"outbox_uri"},
+ "accounts_followers_uri_idx": {"followers_uri"},
+ "accounts_following_uri_idx": {"following_uri"},
+ "accounts_public_key_uri_idx": {"public_key_uri"},
+ } {
+ if _, err := tx.
+ NewCreateIndex().
+ Table("accounts").
+ Index(index).
+ Column(columns...).
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ return nil
+ })
+ }
+
+ down := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ return nil
+ })
+ }
+
+ if err := Migrations.Register(up, down); err != nil {
+ panic(err)
+ }
+}
diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go
index b1e7f45ff..f32aed092 100644
--- a/internal/db/bundb/notification.go
+++ b/internal/db/bundb/notification.go
@@ -33,7 +33,7 @@ type notificationDB struct {
state *state.State
}
-func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
+func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
return n.state.Caches.GTS.Notification().Load("ID", func() (*gtsmodel.Notification, error) {
var notif gtsmodel.Notification
@@ -48,7 +48,7 @@ func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmo
}, id)
}
-func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
+func (n *notificationDB) GetAccountNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
// Ensure reasonable
if limit < 0 {
limit = 0
@@ -92,7 +92,7 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
// reason for this is that for each notif, we can instead get it from our cache if it's cached
for _, id := range notifIDs {
// Attempt fetch from DB
- notif, err := n.GetNotification(ctx, id)
+ notif, err := n.GetNotificationByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting notification %q: %v", id, err)
continue
@@ -105,7 +105,14 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
return notifs, nil
}
-func (n *notificationDB) DeleteNotification(ctx context.Context, id string) db.Error {
+func (n *notificationDB) PutNotification(ctx context.Context, notif *gtsmodel.Notification) error {
+ return n.state.Caches.GTS.Notification().Store(notif, func() error {
+ _, err := n.conn.NewInsert().Model(notif).Exec(ctx)
+ return n.conn.ProcessError(err)
+ })
+}
+
+func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) db.Error {
if _, err := n.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
@@ -118,19 +125,23 @@ func (n *notificationDB) DeleteNotification(ctx context.Context, id string) db.E
return nil
}
-func (n *notificationDB) DeleteNotifications(ctx context.Context, targetAccountID string, originAccountID string) db.Error {
+func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) db.Error {
if targetAccountID == "" && originAccountID == "" {
return errors.New("DeleteNotifications: one of targetAccountID or originAccountID must be set")
}
// Capture notification IDs in a RETURNING statement.
- ids := []string{}
+ var ids []string
q := n.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
Returning("?", bun.Ident("id"))
+ if len(types) > 0 {
+ q = q.Where("? IN (?)", bun.Ident("notification.notification_type"), bun.In(types))
+ }
+
if targetAccountID != "" {
q = q.Where("? = ?", bun.Ident("notification.target_account_id"), targetAccountID)
}
@@ -153,7 +164,7 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, targetAccountI
func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statusID string) db.Error {
// Capture notification IDs in a RETURNING statement.
- ids := []string{}
+ var ids []string
q := n.conn.
NewDelete().
diff --git a/internal/db/bundb/notification_test.go b/internal/db/bundb/notification_test.go
index 117fc329c..bdee911b3 100644
--- a/internal/db/bundb/notification_test.go
+++ b/internal/db/bundb/notification_test.go
@@ -85,11 +85,11 @@ type NotificationTestSuite struct {
BunDBStandardTestSuite
}
-func (suite *NotificationTestSuite) TestGetNotificationsWithSpam() {
+func (suite *NotificationTestSuite) TestGetAccountNotificationsWithSpam() {
suite.spamNotifs()
testAccount := suite.testAccounts["local_account_1"]
before := time.Now()
- notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
+ notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
suite.NoError(err)
timeTaken := time.Since(before)
fmt.Printf("\n\n\n withSpam: got %d notifications in %s\n\n\n", len(notifications), timeTaken)
@@ -100,10 +100,10 @@ func (suite *NotificationTestSuite) TestGetNotificationsWithSpam() {
}
}
-func (suite *NotificationTestSuite) TestGetNotificationsWithoutSpam() {
+func (suite *NotificationTestSuite) TestGetAccountNotificationsWithoutSpam() {
testAccount := suite.testAccounts["local_account_1"]
before := time.Now()
- notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
+ notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
suite.NoError(err)
timeTaken := time.Since(before)
fmt.Printf("\n\n\n withoutSpam: got %d notifications in %s\n\n\n", len(notifications), timeTaken)
@@ -117,10 +117,10 @@ func (suite *NotificationTestSuite) TestGetNotificationsWithoutSpam() {
func (suite *NotificationTestSuite) TestDeleteNotificationsWithSpam() {
suite.spamNotifs()
testAccount := suite.testAccounts["local_account_1"]
- err := suite.db.DeleteNotifications(context.Background(), testAccount.ID, "")
+ err := suite.db.DeleteNotifications(context.Background(), nil, testAccount.ID, "")
suite.NoError(err)
- notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
+ notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
suite.NoError(err)
suite.NotNil(notifications)
suite.Empty(notifications)
@@ -129,10 +129,10 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsWithSpam() {
func (suite *NotificationTestSuite) TestDeleteNotificationsWithTwoAccounts() {
suite.spamNotifs()
testAccount := suite.testAccounts["local_account_1"]
- err := suite.db.DeleteNotifications(context.Background(), testAccount.ID, "")
+ err := suite.db.DeleteNotifications(context.Background(), nil, testAccount.ID, "")
suite.NoError(err)
- notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
+ notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
suite.NoError(err)
suite.NotNil(notifications)
suite.Empty(notifications)
@@ -146,7 +146,7 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsWithTwoAccounts() {
func (suite *NotificationTestSuite) TestDeleteNotificationsOriginatingFromAccount() {
testAccount := suite.testAccounts["local_account_2"]
- if err := suite.db.DeleteNotifications(context.Background(), "", testAccount.ID); err != nil {
+ if err := suite.db.DeleteNotifications(context.Background(), nil, "", testAccount.ID); err != nil {
suite.FailNow(err.Error())
}
@@ -166,7 +166,7 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsOriginatingFromAndTar
originAccount := suite.testAccounts["local_account_2"]
targetAccount := suite.testAccounts["admin_account"]
- if err := suite.db.DeleteNotifications(context.Background(), targetAccount.ID, originAccount.ID); err != nil {
+ if err := suite.db.DeleteNotifications(context.Background(), nil, targetAccount.ID, originAccount.ID); err != nil {
suite.FailNow(err.Error())
}
diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go
index 21a29b5dc..82559a213 100644
--- a/internal/db/bundb/relationship.go
+++ b/internal/db/bundb/relationship.go
@@ -23,8 +23,8 @@ import (
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/uptrace/bun"
)
@@ -34,603 +34,212 @@ type relationshipDB struct {
state *state.State
}
-func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
- // Look for a block in direction of account1->account2
- block1, err := r.getBlock(ctx, account1, account2)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return false, err
- }
-
- if block1 != nil {
- // account1 blocks account2
- return true, nil
- } else if !eitherDirection {
- // Don't check for mutli-directional
- return false, nil
- }
-
- // Look for a block in direction of account2->account1
- block2, err := r.getBlock(ctx, account2, account1)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return false, err
- }
-
- return (block2 != nil), nil
-}
-
-func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
- // Fetch block from database
- block, err := r.getBlock(ctx, account1, account2)
- if err != nil {
- return nil, err
- }
-
- // Set the block originating account
- block.Account, err = r.state.DB.GetAccountByID(ctx, block.AccountID)
- if err != nil {
- return nil, err
- }
-
- // Set the block target account
- block.TargetAccount, err = r.state.DB.GetAccountByID(ctx, block.TargetAccountID)
- if err != nil {
- return nil, err
- }
-
- return block, nil
-}
-
-func (r *relationshipDB) getBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
- return r.state.Caches.GTS.Block().Load("AccountID.TargetAccountID", func() (*gtsmodel.Block, error) {
- var block gtsmodel.Block
-
- q := r.conn.NewSelect().Model(&block).
- Where("? = ?", bun.Ident("block.account_id"), account1).
- Where("? = ?", bun.Ident("block.target_account_id"), account2)
- if err := q.Scan(ctx); err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- return &block, nil
- }, account1, account2)
-}
-
-func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) db.Error {
- return r.state.Caches.GTS.Block().Store(block, func() error {
- _, err := r.conn.NewInsert().Model(block).Exec(ctx)
- return r.conn.ProcessError(err)
- })
-}
-
-func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) db.Error {
- if _, err := r.conn.
- NewDelete().
- TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
- Where("? = ?", bun.Ident("block.id"), id).
- Exec(ctx); err != nil {
- return r.conn.ProcessError(err)
- }
-
- // Drop any old value from cache by this ID
- r.state.Caches.GTS.Block().Invalidate("ID", id)
- return nil
-}
-
-func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) db.Error {
- if _, err := r.conn.
- NewDelete().
- TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
- Where("? = ?", bun.Ident("block.uri"), uri).
- Exec(ctx); err != nil {
- return r.conn.ProcessError(err)
- }
-
- // Drop any old value from cache by this URI
- r.state.Caches.GTS.Block().Invalidate("URI", uri)
- return nil
-}
-
-func (r *relationshipDB) DeleteBlocksByOriginAccountID(ctx context.Context, originAccountID string) db.Error {
- blockIDs := []string{}
-
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
- Column("block.id").
- Where("? = ?", bun.Ident("block.account_id"), originAccountID)
-
- if err := q.Scan(ctx, &blockIDs); err != nil {
- return r.conn.ProcessError(err)
- }
-
- for _, blockID := range blockIDs {
- if err := r.DeleteBlockByID(ctx, blockID); err != nil {
- return err
- }
- }
-
- return nil
-}
-
-func (r *relationshipDB) DeleteBlocksByTargetAccountID(ctx context.Context, targetAccountID string) db.Error {
- blockIDs := []string{}
-
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
- Column("block.id").
- Where("? = ?", bun.Ident("block.target_account_id"), targetAccountID)
-
- if err := q.Scan(ctx, &blockIDs); err != nil {
- return r.conn.ProcessError(err)
- }
-
- for _, blockID := range blockIDs {
- if err := r.DeleteBlockByID(ctx, blockID); err != nil {
- return err
- }
- }
-
- return nil
-}
-
func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
- rel := &gtsmodel.Relationship{
- ID: targetAccount,
+ var rel gtsmodel.Relationship
+ rel.ID = targetAccount
+
+ // check if the requesting follows the target
+ follow, err := r.GetFollow(
+ gtscontext.SetBarebones(ctx),
+ requestingAccount,
+ targetAccount,
+ )
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return nil, fmt.Errorf("GetRelationship: error fetching follow: %w", err)
}
- // check if the requesting account follows the target account
- follow := &gtsmodel.Follow{}
- if err := r.conn.
- NewSelect().
- Model(follow).
- Column("follow.show_reblogs", "follow.notify").
- Where("? = ?", bun.Ident("follow.account_id"), requestingAccount).
- Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount).
- Limit(1).
- Scan(ctx); err != nil {
- if err := r.conn.ProcessError(err); err != db.ErrNoEntries {
- return nil, fmt.Errorf("GetRelationship: error fetching follow: %s", err)
- }
- // no follow exists so these are all false
- rel.Following = false
- rel.ShowingReblogs = false
- rel.Notifying = false
- } else {
+ if follow != nil {
// follow exists so we can fill these fields out...
rel.Following = true
rel.ShowingReblogs = *follow.ShowReblogs
rel.Notifying = *follow.Notify
}
- // check if the target account follows the requesting account
- followedByQ := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
- Column("follow.id").
- Where("? = ?", bun.Ident("follow.account_id"), targetAccount).
- Where("? = ?", bun.Ident("follow.target_account_id"), requestingAccount)
- followedBy, err := r.conn.Exists(ctx, followedByQ)
+ // check if the target follows the requesting
+ rel.FollowedBy, err = r.IsFollowing(ctx,
+ targetAccount,
+ requestingAccount,
+ )
if err != nil {
- return nil, fmt.Errorf("GetRelationship: error checking followedBy: %s", err)
+ return nil, fmt.Errorf("GetRelationship: error checking followedBy: %w", err)
}
- rel.FollowedBy = followedBy
- // check if there's a pending following request from requesting account to target account
- requestedQ := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Column("follow_request.id").
- Where("? = ?", bun.Ident("follow_request.account_id"), requestingAccount).
- Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount)
- requested, err := r.conn.Exists(ctx, requestedQ)
+ // check if requesting has follow requested target
+ rel.Requested, err = r.IsFollowRequested(ctx,
+ requestingAccount,
+ targetAccount,
+ )
if err != nil {
- return nil, fmt.Errorf("GetRelationship: error checking requested: %s", err)
+ return nil, fmt.Errorf("GetRelationship: error checking requested: %w", err)
}
- rel.Requested = requested
// check if the requesting account is blocking the target account
- blockA2T, err := r.getBlock(ctx, requestingAccount, targetAccount)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return nil, fmt.Errorf("GetRelationship: error checking blocking: %s", err)
- }
- rel.Blocking = (blockA2T != nil)
-
- // check if the requesting account is blocked by the target account
- blockT2A, err := r.getBlock(ctx, targetAccount, requestingAccount)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %s", err)
- }
- rel.BlockedBy = (blockT2A != nil)
-
- return rel, nil
-}
-
-func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
- if sourceAccount == nil || targetAccount == nil {
- return false, nil
- }
-
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
- Column("follow.id").
- Where("? = ?", bun.Ident("follow.account_id"), sourceAccount.ID).
- Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount.ID)
-
- return r.conn.Exists(ctx, q)
-}
-
-func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
- if sourceAccount == nil || targetAccount == nil {
- return false, nil
- }
-
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Column("follow_request.id").
- Where("? = ?", bun.Ident("follow_request.account_id"), sourceAccount.ID).
- Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount.ID)
-
- return r.conn.Exists(ctx, q)
-}
-
-func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
- if account1 == nil || account2 == nil {
- return false, nil
- }
-
- // make sure account 1 follows account 2
- f1, err := r.IsFollowing(ctx, account1, account2)
+ rel.Blocking, err = r.IsBlocked(ctx, requestingAccount, targetAccount)
if err != nil {
- return false, err
+ return nil, fmt.Errorf("GetRelationship: error checking blocking: %w", err)
}
- // make sure account 2 follows account 1
- f2, err := r.IsFollowing(ctx, account2, account1)
+ // check if the requesting account is blocked by the target account
+ rel.BlockedBy, err = r.IsBlocked(ctx, targetAccount, requestingAccount)
if err != nil {
- return false, err
+ return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %w", err)
}
- return f1 && f2, nil
+ return &rel, nil
}
-func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
- // Get original follow request.
- var followRequestID string
- if err := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Column("follow_request.id").
- Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
- Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
- Scan(ctx, &followRequestID); err != nil {
+func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
+ var followIDs []string
+ if err := newSelectFollows(r.conn, accountID).
+ Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
-
- followRequest, err := r.getFollowRequest(ctx, followRequestID)
- if err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- // Create a new follow to 'replace'
- // the original follow request with.
- follow := &gtsmodel.Follow{
- ID: followRequest.ID,
- AccountID: originAccountID,
- Account: followRequest.Account,
- TargetAccountID: targetAccountID,
- TargetAccount: followRequest.TargetAccount,
- URI: followRequest.URI,
- }
-
- // If the follow already exists, just
- // replace the URI with the new one.
- if _, err := r.conn.
- NewInsert().
- Model(follow).
- On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI).
- Exec(ctx); err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- // Delete original follow request.
- if _, err := r.conn.
- NewDelete().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
- Exec(ctx); err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- // Delete original follow request notification.
- if err := r.deleteFollowRequestNotif(ctx, originAccountID, targetAccountID); err != nil {
- return nil, err
- }
-
- // return the new follow
- return follow, nil
+ return r.GetFollowsByIDs(ctx, followIDs)
}
-func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) {
- // Get original follow request.
- var followRequestID string
- if err := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Column("follow_request.id").
- Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
- Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
- Scan(ctx, &followRequestID); err != nil {
+func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
+ var followIDs []string
+ if err := newSelectLocalFollows(r.conn, accountID).
+ Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
+ return r.GetFollowsByIDs(ctx, followIDs)
+}
- followRequest, err := r.getFollowRequest(ctx, followRequestID)
- if err != nil {
+func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
+ var followIDs []string
+ if err := newSelectFollowers(r.conn, accountID).
+ Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
+ return r.GetFollowsByIDs(ctx, followIDs)
+}
- // Delete original follow request.
- if _, err := r.conn.
- NewDelete().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
- Exec(ctx); err != nil {
+func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
+ var followIDs []string
+ if err := newSelectLocalFollowers(r.conn, accountID).
+ Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
-
- // Delete original follow request notification.
- if err := r.deleteFollowRequestNotif(ctx, originAccountID, targetAccountID); err != nil {
- return nil, err
- }
-
- // Return the now deleted follow request.
- return followRequest, nil
+ return r.GetFollowsByIDs(ctx, followIDs)
}
-func (r *relationshipDB) deleteFollowRequestNotif(ctx context.Context, originAccountID string, targetAccountID string) db.Error {
- var id string
- if err := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
- Column("notification.id").
- Where("? = ?", bun.Ident("notification.origin_account_id"), originAccountID).
- Where("? = ?", bun.Ident("notification.target_account_id"), targetAccountID).
- Where("? = ?", bun.Ident("notification.notification_type"), gtsmodel.NotificationFollowRequest).
- Limit(1). // There should only be one!
- Scan(ctx, &id); err != nil {
- err = r.conn.ProcessError(err)
- if errors.Is(err, db.ErrNoEntries) {
- // If no entries, the notif didn't
- // exist anyway so nothing to do here.
- return nil
- }
- // Return on real error.
- return err
- }
-
- return r.state.DB.DeleteNotification(ctx, id)
+func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) {
+ n, err := newSelectFollows(r.conn, accountID).Count(ctx)
+ return n, r.conn.ProcessError(err)
}
-func (r *relationshipDB) getFollow(ctx context.Context, id string) (*gtsmodel.Follow, db.Error) {
- follow := &gtsmodel.Follow{}
-
- err := r.conn.
- NewSelect().
- Model(follow).
- Where("? = ?", bun.Ident("follow.id"), id).
- Scan(ctx)
- if err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- follow.Account, err = r.state.DB.GetAccountByID(ctx, follow.AccountID)
- if err != nil {
- log.Errorf(ctx, "error getting follow account %q: %v", follow.AccountID, err)
- }
-
- follow.TargetAccount, err = r.state.DB.GetAccountByID(ctx, follow.TargetAccountID)
- if err != nil {
- log.Errorf(ctx, "error getting follow target account %q: %v", follow.TargetAccountID, err)
- }
-
- return follow, nil
+func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) {
+ n, err := newSelectLocalFollows(r.conn, accountID).Count(ctx)
+ return n, r.conn.ProcessError(err)
}
-func (r *relationshipDB) GetLocalFollowersIDs(ctx context.Context, targetAccountID string) ([]string, db.Error) {
- accountIDs := []string{}
-
- // Select only the account ID of each follow.
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
- ColumnExpr("? AS ?", bun.Ident("follow.account_id"), bun.Ident("account_id")).
- Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID)
-
- // Join on accounts table to select only
- // those with NULL domain (local accounts).
- q = q.
- Join("JOIN ? AS ? ON ? = ?",
- bun.Ident("accounts"),
- bun.Ident("account"),
- bun.Ident("follow.account_id"),
- bun.Ident("account.id"),
- ).
- Where("? IS NULL", bun.Ident("account.domain"))
+func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) {
+ n, err := newSelectFollowers(r.conn, accountID).Count(ctx)
+ return n, r.conn.ProcessError(err)
+}
- // We don't *really* need to order these,
- // but it makes it more consistent to do so.
- q = q.Order("account_id DESC")
+func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) {
+ n, err := newSelectLocalFollowers(r.conn, accountID).Count(ctx)
+ return n, r.conn.ProcessError(err)
+}
- if err := q.Scan(ctx, &accountIDs); err != nil {
+func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
+ var followReqIDs []string
+ if err := newSelectFollowRequests(r.conn, accountID).
+ Scan(ctx, &followReqIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
-
- return accountIDs, nil
+ return r.GetFollowRequestsByIDs(ctx, followReqIDs)
}
-func (r *relationshipDB) GetFollows(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.Follow, db.Error) {
- ids := []string{}
-
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
- Column("follow.id").
- Order("follow.updated_at DESC")
-
- if accountID != "" {
- q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID)
- }
-
- if targetAccountID != "" {
- q = q.Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID)
- }
-
- if err := q.Scan(ctx, &ids); err != nil {
+func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
+ var followReqIDs []string
+ if err := newSelectFollowRequesting(r.conn, accountID).
+ Scan(ctx, &followReqIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
-
- follows := make([]*gtsmodel.Follow, 0, len(ids))
- for _, id := range ids {
- follow, err := r.getFollow(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting follow %q: %v", id, err)
- continue
- }
-
- follows = append(follows, follow)
- }
-
- return follows, nil
+ return r.GetFollowRequestsByIDs(ctx, followReqIDs)
}
-func (r *relationshipDB) CountFollows(ctx context.Context, accountID string, targetAccountID string) (int, db.Error) {
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
- Column("follow.id")
-
- if accountID != "" {
- q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID)
- }
-
- if targetAccountID != "" {
- q = q.Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID)
- }
-
- return q.Count(ctx)
+func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) {
+ n, err := newSelectFollowRequests(r.conn, accountID).Count(ctx)
+ return n, r.conn.ProcessError(err)
}
-func (r *relationshipDB) getFollowRequest(ctx context.Context, id string) (*gtsmodel.FollowRequest, db.Error) {
- followRequest := &gtsmodel.FollowRequest{}
-
- err := r.conn.
- NewSelect().
- Model(followRequest).
- Where("? = ?", bun.Ident("follow_request.id"), id).
- Scan(ctx)
- if err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- followRequest.Account, err = r.state.DB.GetAccountByID(ctx, followRequest.AccountID)
- if err != nil {
- log.Errorf(ctx, "error getting follow request account %q: %v", followRequest.AccountID, err)
- }
-
- followRequest.TargetAccount, err = r.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)
- if err != nil {
- log.Errorf(ctx, "error getting follow request target account %q: %v", followRequest.TargetAccountID, err)
- }
-
- return followRequest, nil
+func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) {
+ n, err := newSelectFollowRequesting(r.conn, accountID).Count(ctx)
+ return n, r.conn.ProcessError(err)
}
-func (r *relationshipDB) GetFollowRequests(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.FollowRequest, db.Error) {
- ids := []string{}
-
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Column("follow_request.id")
-
- if accountID != "" {
- q = q.Where("? = ?", bun.Ident("follow_request.account_id"), accountID)
- }
-
- if targetAccountID != "" {
- q = q.Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID)
- }
-
- if err := q.Scan(ctx, &ids); err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- followRequests := make([]*gtsmodel.FollowRequest, 0, len(ids))
- for _, id := range ids {
- followRequest, err := r.getFollowRequest(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting follow request %q: %v", id, err)
- continue
- }
-
- followRequests = append(followRequests, followRequest)
- }
-
- return followRequests, nil
+// newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID.
+func newSelectFollowRequests(conn *DBConn, accountID string) *bun.SelectQuery {
+ return conn.NewSelect().
+ TableExpr("?", bun.Ident("follow_requests")).
+ ColumnExpr("?", bun.Ident("id")).
+ Where("? = ?", bun.Ident("target_account_id"), accountID).
+ OrderExpr("? DESC", bun.Ident("updated_at"))
}
-func (r *relationshipDB) CountFollowRequests(ctx context.Context, accountID string, targetAccountID string) (int, db.Error) {
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Column("follow_request.id").
- Order("follow_request.updated_at DESC")
-
- if accountID != "" {
- q = q.Where("? = ?", bun.Ident("follow_request.account_id"), accountID)
- }
-
- if targetAccountID != "" {
- q = q.Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID)
- }
-
- return q.Count(ctx)
+// newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID.
+func newSelectFollowRequesting(conn *DBConn, accountID string) *bun.SelectQuery {
+ return conn.NewSelect().
+ TableExpr("?", bun.Ident("follow_requests")).
+ ColumnExpr("?", bun.Ident("id")).
+ Where("? = ?", bun.Ident("target_account_id"), accountID).
+ OrderExpr("? DESC", bun.Ident("updated_at"))
}
-func (r *relationshipDB) Unfollow(ctx context.Context, originAccountID string, targetAccountID string) (string, db.Error) {
- uri := new(string)
-
- _, err := r.conn.
- NewDelete().
- TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
- Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID).
- Where("? = ?", bun.Ident("follow.account_id"), originAccountID).
- Returning("?", bun.Ident("uri")).Exec(ctx, uri)
-
- // Only return proper errors.
- if err = r.conn.ProcessError(err); err != db.ErrNoEntries {
- return *uri, err
- }
-
- return *uri, nil
+// newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID.
+func newSelectFollows(conn *DBConn, accountID string) *bun.SelectQuery {
+ return conn.NewSelect().
+ Table("follows").
+ Column("id").
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ OrderExpr("? DESC", bun.Ident("updated_at"))
}
-func (r *relationshipDB) UnfollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (string, db.Error) {
- uri := new(string)
-
- _, err := r.conn.
- NewDelete().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
- Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
- Returning("?", bun.Ident("uri")).Exec(ctx, uri)
-
- // Only return proper errors.
- if err = r.conn.ProcessError(err); err != db.ErrNoEntries {
- return *uri, err
- }
-
- return *uri, nil
+// newSelectLocalFollows returns a new select query for all rows in the follows table with
+// account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local).
+func newSelectLocalFollows(conn *DBConn, accountID string) *bun.SelectQuery {
+ return conn.NewSelect().
+ Table("follows").
+ Column("id").
+ Where("? = ? AND ? IN (?)",
+ bun.Ident("account_id"),
+ accountID,
+ bun.Ident("target_account_id"),
+ conn.NewSelect().
+ Table("accounts").
+ Column("id").
+ Where("? IS NULL", bun.Ident("domain")),
+ ).
+ OrderExpr("? DESC", bun.Ident("updated_at"))
+}
+
+// newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID.
+func newSelectFollowers(conn *DBConn, accountID string) *bun.SelectQuery {
+ return conn.NewSelect().
+ Table("follows").
+ Column("id").
+ Where("? = ?", bun.Ident("target_account_id"), accountID).
+ OrderExpr("? DESC", bun.Ident("updated_at"))
+}
+
+// newSelectLocalFollowers returns a new select query for all rows in the follows table with
+// target_account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local).
+func newSelectLocalFollowers(conn *DBConn, accountID string) *bun.SelectQuery {
+ return conn.NewSelect().
+ Table("follows").
+ Column("id").
+ Where("? = ? AND ? IN (?)",
+ bun.Ident("target_account_id"),
+ accountID,
+ bun.Ident("account_id"),
+ conn.NewSelect().
+ Table("accounts").
+ Column("id").
+ Where("? IS NULL", bun.Ident("domain")),
+ ).
+ OrderExpr("? DESC", bun.Ident("updated_at"))
}
diff --git a/internal/db/bundb/relationship_block.go b/internal/db/bundb/relationship_block.go
new file mode 100644
index 000000000..9232ea984
--- /dev/null
+++ b/internal/db/bundb/relationship_block.go
@@ -0,0 +1,218 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb
+
+import (
+ "context"
+ "errors"
+ "fmt"
+
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/uptrace/bun"
+)
+
+func (r *relationshipDB) IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) {
+ block, err := r.GetBlock(
+ gtscontext.SetBarebones(ctx),
+ sourceAccountID,
+ targetAccountID,
+ )
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return false, err
+ }
+ return (block != nil), nil
+}
+
+func (r *relationshipDB) IsEitherBlocked(ctx context.Context, accountID1 string, accountID2 string) (bool, error) {
+ // Look for a block in direction of account1->account2
+ b1, err := r.IsBlocked(ctx, accountID1, accountID2)
+ if err != nil || b1 {
+ return true, err
+ }
+
+ // Look for a block in direction of account2->account1
+ b2, err := r.IsBlocked(ctx, accountID2, accountID1)
+ if err != nil || b2 {
+ return true, err
+ }
+
+ return false, nil
+}
+
+func (r *relationshipDB) GetBlockByID(ctx context.Context, id string) (*gtsmodel.Block, error) {
+ return r.getBlock(
+ ctx,
+ "ID",
+ func(block *gtsmodel.Block) error {
+ return r.conn.NewSelect().Model(block).
+ Where("? = ?", bun.Ident("block.id"), id).
+ Scan(ctx)
+ },
+ id,
+ )
+}
+
+func (r *relationshipDB) GetBlockByURI(ctx context.Context, uri string) (*gtsmodel.Block, error) {
+ return r.getBlock(
+ ctx,
+ "URI",
+ func(block *gtsmodel.Block) error {
+ return r.conn.NewSelect().Model(block).
+ Where("? = ?", bun.Ident("block.uri"), uri).
+ Scan(ctx)
+ },
+ uri,
+ )
+}
+
+func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Block, error) {
+ return r.getBlock(
+ ctx,
+ "AccountID.TargetAccountID",
+ func(block *gtsmodel.Block) error {
+ return r.conn.NewSelect().Model(block).
+ Where("? = ?", bun.Ident("block.account_id"), sourceAccountID).
+ Where("? = ?", bun.Ident("block.target_account_id"), targetAccountID).
+ Scan(ctx)
+ },
+ sourceAccountID,
+ targetAccountID,
+ )
+}
+
+func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Block) error, keyParts ...any) (*gtsmodel.Block, error) {
+ // Fetch block from cache with loader callback
+ block, err := r.state.Caches.GTS.Block().Load(lookup, func() (*gtsmodel.Block, error) {
+ var block gtsmodel.Block
+
+ // Not cached! Perform database query
+ if err := dbQuery(&block); err != nil {
+ return nil, r.conn.ProcessError(err)
+ }
+
+ return &block, nil
+ }, keyParts...)
+ if err != nil {
+ // already processe
+ return nil, err
+ }
+
+ if gtscontext.Barebones(ctx) {
+ // Only a barebones model was requested.
+ return block, nil
+ }
+
+ // Set the block source account
+ block.Account, err = r.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ block.AccountID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error getting block source account: %w", err)
+ }
+
+ // Set the block target account
+ block.TargetAccount, err = r.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ block.TargetAccountID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error getting block target account: %w", err)
+ }
+
+ return block, nil
+}
+
+func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) error {
+ err := r.state.Caches.GTS.Block().Store(block, func() error {
+ _, err := r.conn.NewInsert().Model(block).Exec(ctx)
+ return r.conn.ProcessError(err)
+ })
+ if err != nil {
+ return err
+ }
+
+ // Invalidate block origin account ID cached visibility.
+ r.state.Caches.Visibility.Invalidate("ItemID", block.AccountID)
+ r.state.Caches.Visibility.Invalidate("RequesterID", block.AccountID)
+
+ // Invalidate block target account ID cached visibility.
+ r.state.Caches.Visibility.Invalidate("ItemID", block.TargetAccountID)
+ r.state.Caches.Visibility.Invalidate("RequesterID", block.TargetAccountID)
+
+ return nil
+}
+
+func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error {
+ block, err := r.GetBlockByID(gtscontext.SetBarebones(ctx), id)
+ if err != nil {
+ return err
+ }
+ return r.deleteBlock(ctx, block)
+}
+
+func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error {
+ block, err := r.GetBlockByURI(gtscontext.SetBarebones(ctx), uri)
+ if err != nil {
+ return err
+ }
+ return r.deleteBlock(ctx, block)
+}
+
+func (r *relationshipDB) deleteBlock(ctx context.Context, block *gtsmodel.Block) error {
+ if _, err := r.conn.
+ NewDelete().
+ Table("blocks").
+ Where("? = ?", bun.Ident("id"), block.ID).
+ Exec(ctx); err != nil {
+ return r.conn.ProcessError(err)
+ }
+
+ // Invalidate block from cache lookups.
+ r.state.Caches.GTS.Block().Invalidate("ID", block.ID)
+
+ return nil
+}
+
+func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID string) error {
+ var blockIDs []string
+
+ if err := r.conn.NewSelect().
+ Table("blocks").
+ ColumnExpr("?", bun.Ident("id")).
+ WhereOr("? = ? OR ? = ?",
+ bun.Ident("account_id"),
+ accountID,
+ bun.Ident("target_account_id"),
+ accountID,
+ ).
+ Scan(ctx, &blockIDs); err != nil {
+ return r.conn.ProcessError(err)
+ }
+
+ for _, id := range blockIDs {
+ if err := r.DeleteBlockByID(ctx, id); err != nil {
+ log.Errorf(ctx, "error deleting block %q: %v", id, err)
+ }
+ }
+
+ return nil
+}
diff --git a/internal/db/bundb/relationship_follow.go b/internal/db/bundb/relationship_follow.go
new file mode 100644
index 000000000..4a315d116
--- /dev/null
+++ b/internal/db/bundb/relationship_follow.go
@@ -0,0 +1,243 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb
+
+import (
+ "context"
+ "errors"
+ "fmt"
+
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/uptrace/bun"
+)
+
+func (r *relationshipDB) GetFollowByID(ctx context.Context, id string) (*gtsmodel.Follow, error) {
+ return r.getFollow(
+ ctx,
+ "ID",
+ func(follow *gtsmodel.Follow) error {
+ return r.conn.NewSelect().
+ Model(follow).
+ Where("? = ?", bun.Ident("id"), id).
+ Scan(ctx)
+ },
+ id,
+ )
+}
+
+func (r *relationshipDB) GetFollowByURI(ctx context.Context, uri string) (*gtsmodel.Follow, error) {
+ return r.getFollow(
+ ctx,
+ "URI",
+ func(follow *gtsmodel.Follow) error {
+ return r.conn.NewSelect().
+ Model(follow).
+ Where("? = ?", bun.Ident("uri"), uri).
+ Scan(ctx)
+ },
+ uri,
+ )
+}
+
+func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {
+ return r.getFollow(
+ ctx,
+ "AccountID.TargetAccountID",
+ func(follow *gtsmodel.Follow) error {
+ return r.conn.NewSelect().
+ Model(follow).
+ Where("? = ?", bun.Ident("account_id"), sourceAccountID).
+ Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
+ Scan(ctx)
+ },
+ sourceAccountID,
+ targetAccountID,
+ )
+}
+
+func (r *relationshipDB) GetFollowsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Follow, error) {
+ // Preallocate slice of expected length.
+ follows := make([]*gtsmodel.Follow, 0, len(ids))
+
+ for _, id := range ids {
+ // Fetch follow model for this ID.
+ follow, err := r.GetFollowByID(ctx, id)
+ if err != nil {
+ log.Errorf(ctx, "error getting follow %q: %v", id, err)
+ continue
+ }
+
+ // Append to return slice.
+ follows = append(follows, follow)
+ }
+
+ return follows, nil
+}
+
+func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) {
+ follow, err := r.GetFollow(
+ gtscontext.SetBarebones(ctx),
+ sourceAccountID,
+ targetAccountID,
+ )
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return false, err
+ }
+ return (follow != nil), nil
+}
+
+func (r *relationshipDB) IsMutualFollowing(ctx context.Context, accountID1 string, accountID2 string) (bool, db.Error) {
+ // make sure account 1 follows account 2
+ f1, err := r.IsFollowing(ctx,
+ accountID1,
+ accountID2,
+ )
+ if !f1 /* f1 = false when err != nil */ {
+ return false, err
+ }
+
+ // make sure account 2 follows account 1
+ f2, err := r.IsFollowing(ctx,
+ accountID2,
+ accountID1,
+ )
+ if !f2 /* f2 = false when err != nil */ {
+ return false, err
+ }
+
+ return true, nil
+}
+
+func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Follow) error, keyParts ...any) (*gtsmodel.Follow, error) {
+ // Fetch follow from database cache with loader callback
+ follow, err := r.state.Caches.GTS.Follow().Load(lookup, func() (*gtsmodel.Follow, error) {
+ var follow gtsmodel.Follow
+
+ // Not cached! Perform database query
+ if err := dbQuery(&follow); err != nil {
+ return nil, r.conn.ProcessError(err)
+ }
+
+ return &follow, nil
+ }, keyParts...)
+ if err != nil {
+ // error already processed
+ return nil, err
+ }
+
+ if gtscontext.Barebones(ctx) {
+ // Only a barebones model was requested.
+ return follow, nil
+ }
+
+ // Set the follow source account
+ follow.Account, err = r.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ follow.AccountID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error getting follow source account: %w", err)
+ }
+
+ // Set the follow target account
+ follow.TargetAccount, err = r.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ follow.TargetAccountID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error getting follow target account: %w", err)
+ }
+
+ return follow, nil
+}
+
+func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error {
+ err := r.state.Caches.GTS.Follow().Store(follow, func() error {
+ _, err := r.conn.NewInsert().Model(follow).Exec(ctx)
+ return r.conn.ProcessError(err)
+ })
+ if err != nil {
+ return err
+ }
+
+ // Invalidate follow origin account ID cached visibility.
+ r.state.Caches.Visibility.Invalidate("ItemID", follow.AccountID)
+ r.state.Caches.Visibility.Invalidate("RequesterID", follow.AccountID)
+
+ // Invalidate follow target account ID cached visibility.
+ r.state.Caches.Visibility.Invalidate("ItemID", follow.TargetAccountID)
+ r.state.Caches.Visibility.Invalidate("RequesterID", follow.TargetAccountID)
+
+ return nil
+}
+
+func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error {
+ if _, err := r.conn.NewDelete().
+ Table("follows").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx); err != nil {
+ return r.conn.ProcessError(err)
+ }
+
+ // Invalidate follow from cache lookups.
+ r.state.Caches.GTS.Follow().Invalidate("ID", id)
+
+ return nil
+}
+
+func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) error {
+ if _, err := r.conn.NewDelete().
+ Table("follows").
+ Where("? = ?", bun.Ident("uri"), uri).
+ Exec(ctx); err != nil {
+ return r.conn.ProcessError(err)
+ }
+
+ // Invalidate follow from cache lookups.
+ r.state.Caches.GTS.Follow().Invalidate("URI", uri)
+
+ return nil
+}
+
+func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID string) error {
+ var followIDs []string
+
+ if _, err := r.conn.
+ NewDelete().
+ Table("follows").
+ WhereOr("? = ? OR ? = ?",
+ bun.Ident("account_id"),
+ accountID,
+ bun.Ident("target_account_id"),
+ accountID,
+ ).
+ Returning("?", bun.Ident("id")).
+ Exec(ctx, &followIDs); err != nil {
+ return r.conn.ProcessError(err)
+ }
+
+ // Invalidate each returned ID.
+ for _, id := range followIDs {
+ r.state.Caches.GTS.Follow().Invalidate("ID", id)
+ }
+
+ return nil
+}
diff --git a/internal/db/bundb/relationship_follow_req.go b/internal/db/bundb/relationship_follow_req.go
new file mode 100644
index 000000000..11200338d
--- /dev/null
+++ b/internal/db/bundb/relationship_follow_req.go
@@ -0,0 +1,293 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb
+
+import (
+ "context"
+ "errors"
+ "fmt"
+
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/uptrace/bun"
+)
+
+func (r *relationshipDB) GetFollowRequestByID(ctx context.Context, id string) (*gtsmodel.FollowRequest, error) {
+ return r.getFollowRequest(
+ ctx,
+ "ID",
+ func(followReq *gtsmodel.FollowRequest) error {
+ return r.conn.NewSelect().
+ Model(followReq).
+ Where("? = ?", bun.Ident("id"), id).
+ Scan(ctx)
+ },
+ id,
+ )
+}
+
+func (r *relationshipDB) GetFollowRequestByURI(ctx context.Context, uri string) (*gtsmodel.FollowRequest, error) {
+ return r.getFollowRequest(
+ ctx,
+ "URI",
+ func(followReq *gtsmodel.FollowRequest) error {
+ return r.conn.NewSelect().
+ Model(followReq).
+ Where("? = ?", bun.Ident("uri"), uri).
+ Scan(ctx)
+ },
+ uri,
+ )
+}
+
+func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error) {
+ return r.getFollowRequest(
+ ctx,
+ "AccountID.TargetAccountID",
+ func(followReq *gtsmodel.FollowRequest) error {
+ return r.conn.NewSelect().
+ Model(followReq).
+ Where("? = ?", bun.Ident("account_id"), sourceAccountID).
+ Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
+ Scan(ctx)
+ },
+ sourceAccountID,
+ targetAccountID,
+ )
+}
+
+func (r *relationshipDB) GetFollowRequestsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.FollowRequest, error) {
+ // Preallocate slice of expected length.
+ followReqs := make([]*gtsmodel.FollowRequest, 0, len(ids))
+
+ for _, id := range ids {
+ // Fetch follow request model for this ID.
+ followReq, err := r.GetFollowRequestByID(ctx, id)
+ if err != nil {
+ log.Errorf(ctx, "error getting follow request %q: %v", id, err)
+ continue
+ }
+
+ // Append to return slice.
+ followReqs = append(followReqs, followReq)
+ }
+
+ return followReqs, nil
+}
+
+func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) {
+ followReq, err := r.GetFollowRequest(
+ gtscontext.SetBarebones(ctx),
+ sourceAccountID,
+ targetAccountID,
+ )
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return false, err
+ }
+ return (followReq != nil), nil
+}
+
+func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, dbQuery func(*gtsmodel.FollowRequest) error, keyParts ...any) (*gtsmodel.FollowRequest, error) {
+ // Fetch follow request from database cache with loader callback
+ followReq, err := r.state.Caches.GTS.FollowRequest().Load(lookup, func() (*gtsmodel.FollowRequest, error) {
+ var followReq gtsmodel.FollowRequest
+
+ // Not cached! Perform database query
+ if err := dbQuery(&followReq); err != nil {
+ return nil, r.conn.ProcessError(err)
+ }
+
+ return &followReq, nil
+ }, keyParts...)
+ if err != nil {
+ // error already processed
+ return nil, err
+ }
+
+ if gtscontext.Barebones(ctx) {
+ // Only a barebones model was requested.
+ return followReq, nil
+ }
+
+ // Set the follow request source account
+ followReq.Account, err = r.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ followReq.AccountID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error getting follow request source account: %w", err)
+ }
+
+ // Set the follow request target account
+ followReq.TargetAccount, err = r.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ followReq.TargetAccountID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error getting follow request target account: %w", err)
+ }
+
+ return followReq, nil
+}
+
+func (r *relationshipDB) PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error {
+ err := r.state.Caches.GTS.FollowRequest().Store(follow, func() error {
+ _, err := r.conn.NewInsert().Model(follow).Exec(ctx)
+ return r.conn.ProcessError(err)
+ })
+ if err != nil {
+ return err
+ }
+
+ // Invalidate follow request origin account ID cached visibility.
+ r.state.Caches.Visibility.Invalidate("ItemID", follow.AccountID)
+ r.state.Caches.Visibility.Invalidate("RequesterID", follow.AccountID)
+
+ // Invalidate follow request target account ID cached visibility.
+ r.state.Caches.Visibility.Invalidate("ItemID", follow.TargetAccountID)
+ r.state.Caches.Visibility.Invalidate("RequesterID", follow.TargetAccountID)
+
+ return nil
+}
+
+func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
+ // Get original follow request.
+ followReq, err := r.GetFollowRequest(ctx, sourceAccountID, targetAccountID)
+ if err != nil {
+ return nil, err
+ }
+
+ // Create a new follow to 'replace'
+ // the original follow request with.
+ follow := &gtsmodel.Follow{
+ ID: followReq.ID,
+ AccountID: sourceAccountID,
+ Account: followReq.Account,
+ TargetAccountID: targetAccountID,
+ TargetAccount: followReq.TargetAccount,
+ URI: followReq.URI,
+ }
+
+ // If the follow already exists, just
+ // replace the URI with the new one.
+ if _, err := r.conn.
+ NewInsert().
+ Model(follow).
+ On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI).
+ Exec(ctx); err != nil {
+ return nil, r.conn.ProcessError(err)
+ }
+
+ // Delete original follow request.
+ if _, err := r.conn.
+ NewDelete().
+ Table("follow_requests").
+ Where("? = ?", bun.Ident("id"), followReq.ID).
+ Exec(ctx); err != nil {
+ return nil, r.conn.ProcessError(err)
+ }
+
+ // Invalidate follow request from cache lookups.
+ r.state.Caches.GTS.FollowRequest().Invalidate("ID", followReq.ID)
+
+ // Delete original follow request notification
+ if err := r.state.DB.DeleteNotifications(ctx, []string{
+ string(gtsmodel.NotificationFollowRequest),
+ }, targetAccountID, sourceAccountID); err != nil {
+ return nil, err
+ }
+
+ return follow, nil
+}
+
+func (r *relationshipDB) RejectFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) db.Error {
+ // Get original follow request.
+ followReq, err := r.GetFollowRequest(ctx, sourceAccountID, targetAccountID)
+ if err != nil {
+ return err
+ }
+
+ // Delete original follow request.
+ if _, err := r.conn.
+ NewDelete().
+ Table("follow_requests").
+ Where("? = ?", bun.Ident("id"), followReq.ID).
+ Exec(ctx); err != nil {
+ return r.conn.ProcessError(err)
+ }
+
+ // Delete original follow request notification
+ return r.state.DB.DeleteNotifications(ctx, []string{
+ string(gtsmodel.NotificationFollowRequest),
+ }, targetAccountID, sourceAccountID)
+}
+
+func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) error {
+ if _, err := r.conn.NewDelete().
+ Table("follow_requests").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx); err != nil {
+ return r.conn.ProcessError(err)
+ }
+
+ // Invalidate follow request from cache lookups.
+ r.state.Caches.GTS.FollowRequest().Invalidate("ID", id)
+
+ return nil
+}
+
+func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri string) error {
+ if _, err := r.conn.NewDelete().
+ Table("follow_requests").
+ Where("? = ?", bun.Ident("uri"), uri).
+ Exec(ctx); err != nil {
+ return r.conn.ProcessError(err)
+ }
+
+ // Invalidate follow request from cache lookups.
+ r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri)
+
+ return nil
+}
+
+func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accountID string) error {
+ var followIDs []string
+
+ if _, err := r.conn.
+ NewDelete().
+ Table("follow_requests").
+ WhereOr("? = ? OR ? = ?",
+ bun.Ident("account_id"),
+ accountID,
+ bun.Ident("target_account_id"),
+ accountID,
+ ).
+ Returning("?", bun.Ident("id")).
+ Exec(ctx, &followIDs); err != nil {
+ return r.conn.ProcessError(err)
+ }
+
+ // Invalidate each returned ID.
+ for _, id := range followIDs {
+ r.state.Caches.GTS.FollowRequest().Invalidate("ID", id)
+ }
+
+ return nil
+}
diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go
index 3d307ecde..00583d175 100644
--- a/internal/db/bundb/relationship_test.go
+++ b/internal/db/bundb/relationship_test.go
@@ -19,17 +19,359 @@ package bundb_test
import (
"context"
+ "errors"
+ "reflect"
"testing"
+ "time"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/id"
)
type RelationshipTestSuite struct {
BunDBStandardTestSuite
}
+func (suite *RelationshipTestSuite) TestGetBlockBy() {
+ t := suite.T()
+
+ // Create a new context for this test.
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ // Sentinel error to mark avoiding a test case.
+ sentinelErr := errors.New("sentinel")
+
+ // isEqual checks if 2 block models are equal.
+ isEqual := func(b1, b2 gtsmodel.Block) bool {
+ // Clear populated sub-models.
+ b1.Account = nil
+ b2.Account = nil
+ b1.TargetAccount = nil
+ b2.TargetAccount = nil
+
+ // Clear database-set fields.
+ b1.CreatedAt = time.Time{}
+ b2.CreatedAt = time.Time{}
+ b1.UpdatedAt = time.Time{}
+ b2.UpdatedAt = time.Time{}
+
+ return reflect.DeepEqual(b1, b2)
+ }
+
+ var testBlocks []*gtsmodel.Block
+
+ for _, account1 := range suite.testAccounts {
+ for _, account2 := range suite.testAccounts {
+ if account1.ID == account2.ID {
+ // don't block *yourself* ...
+ continue
+ }
+
+ // Create new account block.
+ block := &gtsmodel.Block{
+ ID: id.NewULID(),
+ URI: "http://127.0.0.1:8080/" + id.NewULID(),
+ AccountID: account1.ID,
+ TargetAccountID: account2.ID,
+ }
+
+ // Attempt to place the block in database (if not already).
+ if err := suite.db.PutBlock(ctx, block); err != nil {
+ if err != db.ErrAlreadyExists {
+ // Unrecoverable database error.
+ t.Fatalf("error creating block: %v", err)
+ }
+
+ // Fetch existing block from database between accounts.
+ block, _ = suite.db.GetBlock(ctx, account1.ID, account2.ID)
+ continue
+ }
+
+ // Append generated block to test cases.
+ testBlocks = append(testBlocks, block)
+ }
+ }
+
+ for _, block := range testBlocks {
+ for lookup, dbfunc := range map[string]func() (*gtsmodel.Block, error){
+ "id": func() (*gtsmodel.Block, error) {
+ return suite.db.GetBlockByID(ctx, block.ID)
+ },
+
+ "uri": func() (*gtsmodel.Block, error) {
+ return suite.db.GetBlockByURI(ctx, block.URI)
+ },
+
+ "origin_target": func() (*gtsmodel.Block, error) {
+ return suite.db.GetBlock(ctx, block.AccountID, block.TargetAccountID)
+ },
+ } {
+
+ // Clear database caches.
+ suite.state.Caches.Init()
+
+ t.Logf("checking database lookup %q", lookup)
+
+ // Perform database function.
+ checkBlock, err := dbfunc()
+ if err != nil {
+ if err == sentinelErr {
+ continue
+ }
+
+ t.Errorf("error encountered for database lookup %q: %v", lookup, err)
+ continue
+ }
+
+ // Check received block data.
+ if !isEqual(*checkBlock, *block) {
+ t.Errorf("block does not contain expected data: %+v", checkBlock)
+ continue
+ }
+
+ // Check that block origin account populated.
+ if checkBlock.Account == nil || checkBlock.Account.ID != block.AccountID {
+ t.Errorf("block origin account not correctly populated for: %+v", checkBlock)
+ continue
+ }
+
+ // Check that block target account populated.
+ if checkBlock.TargetAccount == nil || checkBlock.TargetAccount.ID != block.TargetAccountID {
+ t.Errorf("block target account not correctly populated for: %+v", checkBlock)
+ continue
+ }
+ }
+ }
+}
+
+func (suite *RelationshipTestSuite) TestGetFollowBy() {
+ t := suite.T()
+
+ // Create a new context for this test.
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ // Sentinel error to mark avoiding a test case.
+ sentinelErr := errors.New("sentinel")
+
+ // isEqual checks if 2 follow models are equal.
+ isEqual := func(f1, f2 gtsmodel.Follow) bool {
+ // Clear populated sub-models.
+ f1.Account = nil
+ f2.Account = nil
+ f1.TargetAccount = nil
+ f2.TargetAccount = nil
+
+ // Clear database-set fields.
+ f1.CreatedAt = time.Time{}
+ f2.CreatedAt = time.Time{}
+ f1.UpdatedAt = time.Time{}
+ f2.UpdatedAt = time.Time{}
+
+ return reflect.DeepEqual(f1, f2)
+ }
+
+ var testFollows []*gtsmodel.Follow
+
+ for _, account1 := range suite.testAccounts {
+ for _, account2 := range suite.testAccounts {
+ if account1.ID == account2.ID {
+ // don't follow *yourself* ...
+ continue
+ }
+
+ // Create new account follow.
+ follow := &gtsmodel.Follow{
+ ID: id.NewULID(),
+ URI: "http://127.0.0.1:8080/" + id.NewULID(),
+ AccountID: account1.ID,
+ TargetAccountID: account2.ID,
+ }
+
+ // Attempt to place the follow in database (if not already).
+ if err := suite.db.PutFollow(ctx, follow); err != nil {
+ if err != db.ErrAlreadyExists {
+ // Unrecoverable database error.
+ t.Fatalf("error creating follow: %v", err)
+ }
+
+ // Fetch existing follow from database between accounts.
+ follow, _ = suite.db.GetFollow(ctx, account1.ID, account2.ID)
+ continue
+ }
+
+ // Append generated follow to test cases.
+ testFollows = append(testFollows, follow)
+ }
+ }
+
+ for _, follow := range testFollows {
+ for lookup, dbfunc := range map[string]func() (*gtsmodel.Follow, error){
+ "id": func() (*gtsmodel.Follow, error) {
+ return suite.db.GetFollowByID(ctx, follow.ID)
+ },
+
+ "uri": func() (*gtsmodel.Follow, error) {
+ return suite.db.GetFollowByURI(ctx, follow.URI)
+ },
+
+ "origin_target": func() (*gtsmodel.Follow, error) {
+ return suite.db.GetFollow(ctx, follow.AccountID, follow.TargetAccountID)
+ },
+ } {
+ // Clear database caches.
+ suite.state.Caches.Init()
+
+ t.Logf("checking database lookup %q", lookup)
+
+ // Perform database function.
+ checkFollow, err := dbfunc()
+ if err != nil {
+ if err == sentinelErr {
+ continue
+ }
+
+ t.Errorf("error encountered for database lookup %q: %v", lookup, err)
+ continue
+ }
+
+ // Check received follow data.
+ if !isEqual(*checkFollow, *follow) {
+ t.Errorf("follow does not contain expected data: %+v", checkFollow)
+ continue
+ }
+
+ // Check that follow origin account populated.
+ if checkFollow.Account == nil || checkFollow.Account.ID != follow.AccountID {
+ t.Errorf("follow origin account not correctly populated for: %+v", checkFollow)
+ continue
+ }
+
+ // Check that follow target account populated.
+ if checkFollow.TargetAccount == nil || checkFollow.TargetAccount.ID != follow.TargetAccountID {
+ t.Errorf("follow target account not correctly populated for: %+v", checkFollow)
+ continue
+ }
+ }
+ }
+}
+
+func (suite *RelationshipTestSuite) TestGetFollowRequestBy() {
+ t := suite.T()
+
+ // Create a new context for this test.
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ // Sentinel error to mark avoiding a test case.
+ sentinelErr := errors.New("sentinel")
+
+ // isEqual checks if 2 follow request models are equal.
+ isEqual := func(f1, f2 gtsmodel.FollowRequest) bool {
+ // Clear populated sub-models.
+ f1.Account = nil
+ f2.Account = nil
+ f1.TargetAccount = nil
+ f2.TargetAccount = nil
+
+ // Clear database-set fields.
+ f1.CreatedAt = time.Time{}
+ f2.CreatedAt = time.Time{}
+ f1.UpdatedAt = time.Time{}
+ f2.UpdatedAt = time.Time{}
+
+ return reflect.DeepEqual(f1, f2)
+ }
+
+ var testFollowReqs []*gtsmodel.FollowRequest
+
+ for _, account1 := range suite.testAccounts {
+ for _, account2 := range suite.testAccounts {
+ if account1.ID == account2.ID {
+ // don't follow *yourself* ...
+ continue
+ }
+
+ // Create new account follow request.
+ followReq := &gtsmodel.FollowRequest{
+ ID: id.NewULID(),
+ URI: "http://127.0.0.1:8080/" + id.NewULID(),
+ AccountID: account1.ID,
+ TargetAccountID: account2.ID,
+ }
+
+ // Attempt to place the follow in database (if not already).
+ if err := suite.db.PutFollowRequest(ctx, followReq); err != nil {
+ if err != db.ErrAlreadyExists {
+ // Unrecoverable database error.
+ t.Fatalf("error creating follow request: %v", err)
+ }
+
+ // Fetch existing follow request from database between accounts.
+ followReq, _ = suite.db.GetFollowRequest(ctx, account1.ID, account2.ID)
+ continue
+ }
+
+ // Append generated follow request to test cases.
+ testFollowReqs = append(testFollowReqs, followReq)
+ }
+ }
+
+ for _, followReq := range testFollowReqs {
+ for lookup, dbfunc := range map[string]func() (*gtsmodel.FollowRequest, error){
+ "id": func() (*gtsmodel.FollowRequest, error) {
+ return suite.db.GetFollowRequestByID(ctx, followReq.ID)
+ },
+
+ "uri": func() (*gtsmodel.FollowRequest, error) {
+ return suite.db.GetFollowRequestByURI(ctx, followReq.URI)
+ },
+
+ "origin_target": func() (*gtsmodel.FollowRequest, error) {
+ return suite.db.GetFollowRequest(ctx, followReq.AccountID, followReq.TargetAccountID)
+ },
+ } {
+
+ // Clear database caches.
+ suite.state.Caches.Init()
+
+ t.Logf("checking database lookup %q", lookup)
+
+ // Perform database function.
+ checkFollowReq, err := dbfunc()
+ if err != nil {
+ if err == sentinelErr {
+ continue
+ }
+
+ t.Errorf("error encountered for database lookup %q: %v", lookup, err)
+ continue
+ }
+
+ // Check received follow request data.
+ if !isEqual(*checkFollowReq, *followReq) {
+ t.Errorf("follow request does not contain expected data: %+v", checkFollowReq)
+ continue
+ }
+
+ // Check that follow request origin account populated.
+ if checkFollowReq.Account == nil || checkFollowReq.Account.ID != followReq.AccountID {
+ t.Errorf("follow request origin account not correctly populated for: %+v", checkFollowReq)
+ continue
+ }
+
+ // Check that follow request target account populated.
+ if checkFollowReq.TargetAccount == nil || checkFollowReq.TargetAccount.ID != followReq.TargetAccountID {
+ t.Errorf("follow request target account not correctly populated for: %+v", checkFollowReq)
+ continue
+ }
+ }
+ }
+}
+
func (suite *RelationshipTestSuite) TestIsBlocked() {
ctx := context.Background()
@@ -37,11 +379,11 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {
account2 := suite.testAccounts["local_account_2"].ID
// no blocks exist between account 1 and account 2
- blocked, err := suite.db.IsBlocked(ctx, account1, account2, false)
+ blocked, err := suite.db.IsBlocked(ctx, account1, account2)
suite.NoError(err)
suite.False(blocked)
- blocked, err = suite.db.IsBlocked(ctx, account2, account1, false)
+ blocked, err = suite.db.IsBlocked(ctx, account2, account1)
suite.NoError(err)
suite.False(blocked)
@@ -56,45 +398,24 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {
}
// account 1 now blocks account 2
- blocked, err = suite.db.IsBlocked(ctx, account1, account2, false)
+ blocked, err = suite.db.IsBlocked(ctx, account1, account2)
suite.NoError(err)
suite.True(blocked)
// account 2 doesn't block account 1
- blocked, err = suite.db.IsBlocked(ctx, account2, account1, false)
+ blocked, err = suite.db.IsBlocked(ctx, account2, account1)
suite.NoError(err)
suite.False(blocked)
// a block exists in either direction between the two
- blocked, err = suite.db.IsBlocked(ctx, account1, account2, true)
+ blocked, err = suite.db.IsEitherBlocked(ctx, account1, account2)
suite.NoError(err)
suite.True(blocked)
- blocked, err = suite.db.IsBlocked(ctx, account2, account1, true)
+ blocked, err = suite.db.IsEitherBlocked(ctx, account2, account1)
suite.NoError(err)
suite.True(blocked)
}
-func (suite *RelationshipTestSuite) TestGetBlock() {
- ctx := context.Background()
-
- account1 := suite.testAccounts["local_account_1"].ID
- account2 := suite.testAccounts["local_account_2"].ID
-
- if err := suite.db.PutBlock(ctx, &gtsmodel.Block{
- ID: "01G202BCSXXJZ70BHB5KCAHH8C",
- URI: "http://localhost:8080/some_block_uri_1",
- AccountID: account1,
- TargetAccountID: account2,
- }); err != nil {
- suite.FailNow(err.Error())
- }
-
- block, err := suite.db.GetBlock(ctx, account1, account2)
- suite.NoError(err)
- suite.NotNil(block)
- suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID)
-}
-
func (suite *RelationshipTestSuite) TestDeleteBlockByID() {
ctx := context.Background()
@@ -157,7 +478,7 @@ func (suite *RelationshipTestSuite) TestDeleteBlockByURI() {
suite.Nil(block)
}
-func (suite *RelationshipTestSuite) TestDeleteBlocksByOriginAccountID() {
+func (suite *RelationshipTestSuite) TestDeleteAccountBlocks() {
ctx := context.Background()
// put a block in first
@@ -179,38 +500,7 @@ func (suite *RelationshipTestSuite) TestDeleteBlocksByOriginAccountID() {
suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID)
// delete the block by originAccountID
- err = suite.db.DeleteBlocksByOriginAccountID(ctx, account1)
- suite.NoError(err)
-
- // block should be gone
- block, err = suite.db.GetBlock(ctx, account1, account2)
- suite.ErrorIs(err, db.ErrNoEntries)
- suite.Nil(block)
-}
-
-func (suite *RelationshipTestSuite) TestDeleteBlocksByTargetAccountID() {
- ctx := context.Background()
-
- // put a block in first
- account1 := suite.testAccounts["local_account_1"].ID
- account2 := suite.testAccounts["local_account_2"].ID
- if err := suite.db.PutBlock(ctx, &gtsmodel.Block{
- ID: "01G202BCSXXJZ70BHB5KCAHH8C",
- URI: "http://localhost:8080/some_block_uri_1",
- AccountID: account1,
- TargetAccountID: account2,
- }); err != nil {
- suite.FailNow(err.Error())
- }
-
- // make sure the block is in the db
- block, err := suite.db.GetBlock(ctx, account1, account2)
- suite.NoError(err)
- suite.NotNil(block)
- suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID)
-
- // delete the block by targetAccountID
- err = suite.db.DeleteBlocksByTargetAccountID(ctx, account2)
+ err = suite.db.DeleteAccountBlocks(ctx, account1)
suite.NoError(err)
// block should be gone
@@ -244,7 +534,7 @@ func (suite *RelationshipTestSuite) TestGetRelationship() {
func (suite *RelationshipTestSuite) TestIsFollowingYes() {
requestingAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["admin_account"]
- isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount)
+ isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.True(isFollowing)
}
@@ -252,7 +542,7 @@ func (suite *RelationshipTestSuite) TestIsFollowingYes() {
func (suite *RelationshipTestSuite) TestIsFollowingNo() {
requestingAccount := suite.testAccounts["admin_account"]
targetAccount := suite.testAccounts["local_account_2"]
- isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount)
+ isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.False(isFollowing)
}
@@ -260,7 +550,7 @@ func (suite *RelationshipTestSuite) TestIsFollowingNo() {
func (suite *RelationshipTestSuite) TestIsMutualFollowing() {
requestingAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["admin_account"]
- isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount)
+ isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.True(isMutualFollowing)
}
@@ -268,7 +558,7 @@ func (suite *RelationshipTestSuite) TestIsMutualFollowing() {
func (suite *RelationshipTestSuite) TestIsMutualFollowingNo() {
requestingAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["local_account_2"]
- isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount)
+ isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.True(isMutualFollowing)
}
@@ -306,7 +596,7 @@ func (suite *RelationshipTestSuite) TestAcceptFollowRequestOK() {
suite.Equal(followRequest.URI, follow.URI)
// Ensure notification is deleted.
- notification, err := suite.db.GetNotification(ctx, followRequestNotification.ID)
+ notification, err := suite.db.GetNotificationByID(ctx, followRequestNotification.ID)
suite.ErrorIs(err, db.ErrNoEntries)
suite.Nil(notification)
}
@@ -389,7 +679,7 @@ func (suite *RelationshipTestSuite) TestRejectFollowRequestOK() {
TargetAccountID: targetAccount.ID,
}
- if err := suite.db.Put(ctx, followRequest); err != nil {
+ if err := suite.db.PutFollowRequest(ctx, followRequest); err != nil {
suite.FailNow(err.Error())
}
@@ -404,12 +694,11 @@ func (suite *RelationshipTestSuite) TestRejectFollowRequestOK() {
suite.FailNow(err.Error())
}
- rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
+ err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
suite.NoError(err)
- suite.NotNil(rejectedFollowRequest)
// Ensure notification is deleted.
- notification, err := suite.db.GetNotification(ctx, followRequestNotification.ID)
+ notification, err := suite.db.GetNotificationByID(ctx, followRequestNotification.ID)
suite.ErrorIs(err, db.ErrNoEntries)
suite.Nil(notification)
}
@@ -419,9 +708,8 @@ func (suite *RelationshipTestSuite) TestRejectFollowRequestNotExisting() {
account := suite.testAccounts["admin_account"]
targetAccount := suite.testAccounts["local_account_2"]
- rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
+ err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
suite.ErrorIs(err, db.ErrNoEntries)
- suite.Nil(rejectedFollowRequest)
}
func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() {
@@ -440,42 +728,49 @@ func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() {
suite.FailNow(err.Error())
}
- followRequests, err := suite.db.GetFollowRequests(ctx, "", targetAccount.ID)
+ followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID)
suite.NoError(err)
suite.Len(followRequests, 1)
}
func (suite *RelationshipTestSuite) TestGetAccountFollows() {
account := suite.testAccounts["local_account_1"]
- follows, err := suite.db.GetFollows(context.Background(), account.ID, "")
+ follows, err := suite.db.GetAccountFollows(context.Background(), account.ID)
suite.NoError(err)
suite.Len(follows, 2)
}
+func (suite *RelationshipTestSuite) TestCountAccountFollowsLocalOnly() {
+ account := suite.testAccounts["local_account_1"]
+ followsCount, err := suite.db.CountAccountLocalFollows(context.Background(), account.ID)
+ suite.NoError(err)
+ suite.Equal(2, followsCount)
+}
+
func (suite *RelationshipTestSuite) TestCountAccountFollows() {
account := suite.testAccounts["local_account_1"]
- followsCount, err := suite.db.CountFollows(context.Background(), account.ID, "")
+ followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID)
suite.NoError(err)
suite.Equal(2, followsCount)
}
-func (suite *RelationshipTestSuite) TestGetAccountFollowedBy() {
+func (suite *RelationshipTestSuite) TestGetAccountFollowers() {
account := suite.testAccounts["local_account_1"]
- follows, err := suite.db.GetFollows(context.Background(), "", account.ID)
+ follows, err := suite.db.GetAccountFollowers(context.Background(), account.ID)
suite.NoError(err)
suite.Len(follows, 2)
}
-func (suite *RelationshipTestSuite) TestGetLocalFollowersIDs() {
+func (suite *RelationshipTestSuite) TestCountAccountFollowers() {
account := suite.testAccounts["local_account_1"]
- accountIDs, err := suite.db.GetLocalFollowersIDs(context.Background(), account.ID)
+ followsCount, err := suite.db.CountAccountFollowers(context.Background(), account.ID)
suite.NoError(err)
- suite.EqualValues([]string{"01F8MH5NBDF2MV7CTC4Q5128HF", "01F8MH17FWEB39HZJ76B6VXSKF"}, accountIDs)
+ suite.Equal(2, followsCount)
}
-func (suite *RelationshipTestSuite) TestCountAccountFollowedBy() {
+func (suite *RelationshipTestSuite) TestCountAccountFollowersLocalOnly() {
account := suite.testAccounts["local_account_1"]
- followsCount, err := suite.db.CountFollows(context.Background(), "", account.ID)
+ followsCount, err := suite.db.CountAccountLocalFollowers(context.Background(), account.ID)
suite.NoError(err)
suite.Equal(2, followsCount)
}
@@ -484,18 +779,25 @@ func (suite *RelationshipTestSuite) TestUnfollowExisting() {
originAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["admin_account"]
- uri, err := suite.db.Unfollow(context.Background(), originAccount.ID, targetAccount.ID)
+ follow, err := suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID)
suite.NoError(err)
- suite.Equal("http://localhost:8080/users/the_mighty_zork/follow/01F8PY8RHWRQZV038T4E8T9YK8", uri)
+ suite.NotNil(follow)
+
+ err = suite.db.DeleteFollowByID(context.Background(), follow.ID)
+ suite.NoError(err)
+
+ follow, err = suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID)
+ suite.EqualError(err, db.ErrNoEntries.Error())
+ suite.Nil(follow)
}
func (suite *RelationshipTestSuite) TestUnfollowNotExisting() {
originAccount := suite.testAccounts["local_account_1"]
targetAccountID := "01GTVD9N484CZ6AM90PGGNY7GQ"
- uri, err := suite.db.Unfollow(context.Background(), originAccount.ID, targetAccountID)
- suite.NoError(err)
- suite.Empty(uri)
+ follow, err := suite.db.GetFollow(context.Background(), originAccount.ID, targetAccountID)
+ suite.EqualError(err, db.ErrNoEntries.Error())
+ suite.Nil(follow)
}
func (suite *RelationshipTestSuite) TestUnfollowRequestExisting() {
@@ -510,22 +812,29 @@ func (suite *RelationshipTestSuite) TestUnfollowRequestExisting() {
TargetAccountID: targetAccount.ID,
}
- if err := suite.db.Put(ctx, followRequest); err != nil {
+ if err := suite.db.PutFollowRequest(ctx, followRequest); err != nil {
suite.FailNow(err.Error())
}
- uri, err := suite.db.UnfollowRequest(context.Background(), originAccount.ID, targetAccount.ID)
+ followRequest, err := suite.db.GetFollowRequest(context.Background(), originAccount.ID, targetAccount.ID)
+ suite.NoError(err)
+ suite.NotNil(followRequest)
+
+ err = suite.db.DeleteFollowRequestByID(context.Background(), followRequest.ID)
suite.NoError(err)
- suite.Equal("http://localhost:8080/weeeeeeeeeeeeeeeee", uri)
+
+ followRequest, err = suite.db.GetFollowRequest(context.Background(), originAccount.ID, targetAccount.ID)
+ suite.EqualError(err, db.ErrNoEntries.Error())
+ suite.Nil(followRequest)
}
func (suite *RelationshipTestSuite) TestUnfollowRequestNotExisting() {
originAccount := suite.testAccounts["local_account_1"]
targetAccountID := "01GTVD9N484CZ6AM90PGGNY7GQ"
- uri, err := suite.db.UnfollowRequest(context.Background(), originAccount.ID, targetAccountID)
- suite.NoError(err)
- suite.Empty(uri)
+ followRequest, err := suite.db.GetFollowRequest(context.Background(), originAccount.ID, targetAccountID)
+ suite.EqualError(err, db.ErrNoEntries.Error())
+ suite.Nil(followRequest)
}
func TestRelationshipTestSuite(t *testing.T) {
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go
index deec9a118..c2b5546f8 100644
--- a/internal/db/bundb/status.go
+++ b/internal/db/bundb/status.go
@@ -26,6 +26,7 @@ import (
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
@@ -41,7 +42,6 @@ func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
return s.conn.
NewSelect().
Model(status).
- Relation("Attachments").
Relation("Tags").
Relation("CreatedWithApplication")
}
@@ -102,79 +102,141 @@ func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*g
status, err := s.state.Caches.GTS.Status().Load(lookup, func() (*gtsmodel.Status, error) {
var status gtsmodel.Status
- // Not cached! Perform database query
+ // Not cached! Perform database query.
if err := dbQuery(&status); err != nil {
return nil, s.conn.ProcessError(err)
}
- if status.InReplyToID != "" {
- // Also load in-reply-to status
- status.InReplyTo = new(gtsmodel.Status)
- err := s.conn.NewSelect().Model(status.InReplyTo).
- Where("? = ?", bun.Ident("status.id"), status.InReplyToID).
- Scan(ctx)
+ return &status, nil
+ }, keyParts...)
+ if err != nil {
+ return nil, err
+ }
+
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return status, nil
+ }
+
+ // Further populate the status fields where applicable.
+ if err := s.PopulateStatus(ctx, status); err != nil {
+ return nil, err
+ }
+
+ return status, nil
+}
+
+func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) error {
+ var err error
+
+ if status.Account == nil {
+ // Status author is not set, fetch from database.
+ status.Account, err = s.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ status.AccountID,
+ )
+ if err != nil {
+ return fmt.Errorf("error populating status author: %w", err)
+ }
+ }
+
+ if status.InReplyToID != "" && status.InReplyTo == nil {
+ // Status parent is not set, fetch from database.
+ status.InReplyTo, err = s.GetStatusByID(
+ gtscontext.SetBarebones(ctx),
+ status.InReplyToID,
+ )
+ if err != nil {
+ return fmt.Errorf("error populating status parent: %w", err)
+ }
+ }
+
+ if status.InReplyToID != "" {
+ if status.InReplyTo == nil {
+ // Status parent is not set, fetch from database.
+ status.InReplyTo, err = s.GetStatusByID(
+ gtscontext.SetBarebones(ctx),
+ status.InReplyToID,
+ )
if err != nil {
- return nil, s.conn.ProcessError(err)
+ return fmt.Errorf("error populating status parent: %w", err)
}
}
- if status.BoostOfID != "" {
- // Also load original boosted status
- status.BoostOf = new(gtsmodel.Status)
- err := s.conn.NewSelect().Model(status.BoostOf).
- Where("? = ?", bun.Ident("status.id"), status.BoostOfID).
- Scan(ctx)
+ if status.InReplyToAccount == nil {
+ // Status parent author is not set, fetch from database.
+ status.InReplyToAccount, err = s.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ status.InReplyToAccountID,
+ )
if err != nil {
- return nil, s.conn.ProcessError(err)
+ return fmt.Errorf("error populating status parent author: %w", err)
}
}
-
- return &status, nil
- }, keyParts...)
- if err != nil {
- // error already processed
- return nil, err
}
- // Set the status author account
- status.Account, err = s.state.DB.GetAccountByID(ctx, status.AccountID)
- if err != nil {
- return nil, fmt.Errorf("error getting status account: %w", err)
- }
+ if status.BoostOfID != "" {
+ if status.BoostOf == nil {
+ // Status boost is not set, fetch from database.
+ status.BoostOf, err = s.GetStatusByID(
+ gtscontext.SetBarebones(ctx),
+ status.BoostOfID,
+ )
+ if err != nil {
+ return fmt.Errorf("error populating status boost: %w", err)
+ }
+ }
- if id := status.BoostOfAccountID; id != "" {
- // Set boost of status' author account
- status.BoostOfAccount, err = s.state.DB.GetAccountByID(ctx, id)
- if err != nil {
- return nil, fmt.Errorf("error getting boosted status account: %w", err)
+ if status.BoostOfAccount == nil {
+ // Status boost author is not set, fetch from database.
+ status.BoostOfAccount, err = s.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ status.BoostOfAccountID,
+ )
+ if err != nil {
+ return fmt.Errorf("error populating status boost author: %w", err)
+ }
}
}
- if id := status.InReplyToAccountID; id != "" {
- // Set in-reply-to status' author account
- status.InReplyToAccount, err = s.state.DB.GetAccountByID(ctx, id)
+ if !status.AttachmentsPopulated() {
+ // Status attachments are out-of-date with IDs, repopulate.
+ status.Attachments, err = s.state.DB.GetAttachmentsByIDs(
+ ctx, // these are already barebones
+ status.AttachmentIDs,
+ )
if err != nil {
- return nil, fmt.Errorf("error getting in reply to status account: %w", err)
+ return fmt.Errorf("error populating status attachments: %w", err)
}
}
- if len(status.EmojiIDs) > 0 {
- // Fetch status emojis
- status.Emojis, err = s.state.DB.GetEmojisByIDs(ctx, status.EmojiIDs)
+ // TODO: once we don't fetch using relations.
+ // if !status.TagsPopulated() {
+ // }
+
+ if !status.MentionsPopulated() {
+ // Status mentions are out-of-date with IDs, repopulate.
+ status.Mentions, err = s.state.DB.GetMentions(
+ ctx, // leave fully populated for now
+ status.MentionIDs,
+ )
if err != nil {
- return nil, fmt.Errorf("error getting status emojis: %w", err)
+ return fmt.Errorf("error populating status mentions: %w", err)
}
}
- if len(status.MentionIDs) > 0 {
- // Fetch status mentions
- status.Mentions, err = s.state.DB.GetMentions(ctx, status.MentionIDs)
+ if !status.EmojisPopulated() {
+ // Status emojis are out-of-date with IDs, repopulate.
+ status.Emojis, err = s.state.DB.GetEmojisByIDs(
+ ctx, // these are already barebones
+ status.EmojiIDs,
+ )
if err != nil {
- return nil, fmt.Errorf("error getting status mentions: %w", err)
+ return fmt.Errorf("error populating status emojis: %w", err)
}
}
- return status, nil
+ return nil
}
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
@@ -239,12 +301,16 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
})
})
if err != nil {
- // already processed
return err
}
for _, id := range status.AttachmentIDs {
- // Clear updated media attachment IDs from cache
+ // Invalidate media attachments from cache.
+ //
+ // NOTE: this is needed due to the way in which
+ // we upload status attachments, and only after
+ // update them with a known status ID. This is
+ // not the case for header/avatar attachments.
s.state.Caches.GTS.Media().Invalidate("ID", id)
}
@@ -322,14 +388,19 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co
return err
}
+ // Invalidate status from database lookups.
+ s.state.Caches.GTS.Status().Invalidate("ID", status.ID)
+
for _, id := range status.AttachmentIDs {
- // Clear updated media attachment IDs from cache
+ // Invalidate media attachments from cache.
+ //
+ // NOTE: this is needed due to the way in which
+ // we upload status attachments, and only after
+ // update them with a known status ID. This is
+ // not the case for header/avatar attachments.
s.state.Caches.GTS.Media().Invalidate("ID", id)
}
- // Drop any old status value from cache by this ID
- s.state.Caches.GTS.Status().Invalidate("ID", status.ID)
-
return nil
}
@@ -367,8 +438,12 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {
return err
}
- // Drop any old value from cache by this ID
+ // Invalidate status from database lookups.
s.state.Caches.GTS.Status().Invalidate("ID", id)
+
+ // Invalidate status from all visibility lookups.
+ s.state.Caches.Visibility.Invalidate("ItemID", id)
+
return nil
}
diff --git a/internal/db/bundb/statusfave.go b/internal/db/bundb/statusfave.go
index c42ab249f..0f7e5df74 100644
--- a/internal/db/bundb/statusfave.go
+++ b/internal/db/bundb/statusfave.go
@@ -23,6 +23,7 @@ import (
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
@@ -34,29 +35,82 @@ type statusFaveDB struct {
state *state.State
}
-func (s *statusFaveDB) GetStatusFave(ctx context.Context, id string) (*gtsmodel.StatusFave, db.Error) {
- fave := new(gtsmodel.StatusFave)
+func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, db.Error) {
+ return s.getStatusFave(
+ ctx,
+ "AccountID.StatusID",
+ func(fave *gtsmodel.StatusFave) error {
+ return s.conn.
+ NewSelect().
+ Model(fave).
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ Where("? = ?", bun.Ident("status_id"), statusID).
+ Scan(ctx)
+ },
+ accountID,
+ statusID,
+ )
+}
- err := s.conn.
- NewSelect().
- Model(fave).
- Where("? = ?", bun.Ident("status_fave.ID"), id).
- Scan(ctx)
+func (s *statusFaveDB) GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, db.Error) {
+ return s.getStatusFave(
+ ctx,
+ "ID",
+ func(fave *gtsmodel.StatusFave) error {
+ return s.conn.
+ NewSelect().
+ Model(fave).
+ Where("? = ?", bun.Ident("id"), id).
+ Scan(ctx)
+ },
+ id,
+ )
+}
+
+func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery func(*gtsmodel.StatusFave) error, keyParts ...any) (*gtsmodel.StatusFave, error) {
+ // Fetch status fave from database cache with loader callback
+ fave, err := s.state.Caches.GTS.StatusFave().Load(lookup, func() (*gtsmodel.StatusFave, error) {
+ var fave gtsmodel.StatusFave
+
+ // Not cached! Perform database query.
+ if err := dbQuery(&fave); err != nil {
+ return nil, s.conn.ProcessError(err)
+ }
+
+ return &fave, nil
+ }, keyParts...)
if err != nil {
- return nil, s.conn.ProcessError(err)
+ return nil, err
+ }
+
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return fave, nil
}
- fave.Account, err = s.state.DB.GetAccountByID(ctx, fave.AccountID)
+ // Fetch the status fave author account.
+ fave.Account, err = s.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ fave.AccountID,
+ )
if err != nil {
return nil, fmt.Errorf("error getting status fave account %q: %w", fave.AccountID, err)
}
- fave.TargetAccount, err = s.state.DB.GetAccountByID(ctx, fave.TargetAccountID)
+ // Fetch the status fave target account.
+ fave.TargetAccount, err = s.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ fave.TargetAccountID,
+ )
if err != nil {
return nil, fmt.Errorf("error getting status fave target account %q: %w", fave.TargetAccountID, err)
}
- fave.Status, err = s.state.DB.GetStatusByID(ctx, fave.StatusID)
+ // Fetch the status fave target status.
+ fave.Status, err = s.state.DB.GetStatusByID(
+ gtscontext.SetBarebones(ctx),
+ fave.StatusID,
+ )
if err != nil {
return nil, fmt.Errorf("error getting status fave status %q: %w", fave.StatusID, err)
}
@@ -64,38 +118,22 @@ func (s *statusFaveDB) GetStatusFave(ctx context.Context, id string) (*gtsmodel.
return fave, nil
}
-func (s *statusFaveDB) GetStatusFaveByAccountID(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, db.Error) {
- var id string
-
- err := s.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
- Column("status_fave.id").
- Where("? = ?", bun.Ident("status_fave.account_id"), accountID).
- Where("? = ?", bun.Ident("status_fave.status_id"), statusID).
- Scan(ctx, &id)
- if err != nil {
- return nil, s.conn.ProcessError(err)
- }
-
- return s.GetStatusFave(ctx, id)
-}
-
-func (s *statusFaveDB) GetStatusFaves(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, db.Error) {
+func (s *statusFaveDB) GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, db.Error) {
ids := []string{}
if err := s.conn.
NewSelect().
- TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
- Column("status_fave.id").
- Where("? = ?", bun.Ident("status_fave.status_id"), statusID).
+ Table("status_faves").
+ Column("id").
+ Where("? = ?", bun.Ident("status_id"), statusID).
Scan(ctx, &ids); err != nil {
return nil, s.conn.ProcessError(err)
}
faves := make([]*gtsmodel.StatusFave, 0, len(ids))
+
for _, id := range ids {
- fave, err := s.GetStatusFave(ctx, id)
+ fave, err := s.GetStatusFaveByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting status fave %q: %v", id, err)
continue
@@ -107,23 +145,27 @@ func (s *statusFaveDB) GetStatusFaves(ctx context.Context, statusID string) ([]*
return faves, nil
}
-func (s *statusFaveDB) PutStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) db.Error {
- _, err := s.conn.
- NewInsert().
- Model(statusFave).
- Exec(ctx)
-
- return s.conn.ProcessError(err)
+func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusFave) db.Error {
+ return s.state.Caches.GTS.StatusFave().Store(fave, func() error {
+ _, err := s.conn.
+ NewInsert().
+ Model(fave).
+ Exec(ctx)
+ return s.conn.ProcessError(err)
+ })
}
-func (s *statusFaveDB) DeleteStatusFave(ctx context.Context, id string) db.Error {
- _, err := s.conn.
+func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) db.Error {
+ if _, err := s.conn.
NewDelete().
- TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
- Where("? = ?", bun.Ident("status_fave.id"), id).
- Exec(ctx)
+ Table("status_faves").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx); err != nil {
+ return s.conn.ProcessError(err)
+ }
- return s.conn.ProcessError(err)
+ s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
+ return nil
}
func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) db.Error {
@@ -131,42 +173,52 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st
return errors.New("DeleteStatusFaves: one of targetAccountID or originAccountID must be set")
}
- // TODO: Capture fave IDs in a RETURNING
- // statement (when faves have a cache),
- // + use the IDs to invalidate cache entries.
+ // Capture fave IDs in a RETURNING statement.
+ var faveIDs []string
q := s.conn.
NewDelete().
- TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave"))
+ Table("status_faves").
+ Returning("?", bun.Ident("id"))
if targetAccountID != "" {
- q = q.Where("? = ?", bun.Ident("status_fave.target_account_id"), targetAccountID)
+ q = q.Where("? = ?", bun.Ident("target_account_id"), targetAccountID)
}
if originAccountID != "" {
- q = q.Where("? = ?", bun.Ident("status_fave.account_id"), originAccountID)
+ q = q.Where("? = ?", bun.Ident("account_id"), originAccountID)
}
- if _, err := q.Exec(ctx); err != nil {
+ if _, err := q.Exec(ctx, &faveIDs); err != nil {
return s.conn.ProcessError(err)
}
+ for _, id := range faveIDs {
+ // Invalidate each of the returned status fave IDs.
+ s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
+ }
+
return nil
}
func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID string) db.Error {
- // TODO: Capture fave IDs in a RETURNING
- // statement (when faves have a cache),
- // + use the IDs to invalidate cache entries.
+ // Capture fave IDs in a RETURNING statement.
+ var faveIDs []string
q := s.conn.
NewDelete().
- TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
- Where("? = ?", bun.Ident("status_fave.status_id"), statusID)
+ Table("status_faves").
+ Where("? = ?", bun.Ident("status_id"), statusID).
+ Returning("?", bun.Ident("id"))
- if _, err := q.Exec(ctx); err != nil {
+ if _, err := q.Exec(ctx, &faveIDs); err != nil {
return s.conn.ProcessError(err)
}
+ for _, id := range faveIDs {
+ // Invalidate each of the returned status fave IDs.
+ s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
+ }
+
return nil
}
diff --git a/internal/db/bundb/statusfave_test.go b/internal/db/bundb/statusfave_test.go
index 98e495bf3..7218390bc 100644
--- a/internal/db/bundb/statusfave_test.go
+++ b/internal/db/bundb/statusfave_test.go
@@ -35,7 +35,7 @@ type StatusFaveTestSuite struct {
func (suite *StatusFaveTestSuite) TestGetStatusFaves() {
testStatus := suite.testStatuses["admin_account_status_1"]
- faves, err := suite.db.GetStatusFaves(context.Background(), testStatus.ID)
+ faves, err := suite.db.GetStatusFavesForStatus(context.Background(), testStatus.ID)
if err != nil {
suite.FailNow(err.Error())
}
@@ -51,7 +51,7 @@ func (suite *StatusFaveTestSuite) TestGetStatusFaves() {
func (suite *StatusFaveTestSuite) TestGetStatusFavesNone() {
testStatus := suite.testStatuses["admin_account_status_4"]
- faves, err := suite.db.GetStatusFaves(context.Background(), testStatus.ID)
+ faves, err := suite.db.GetStatusFavesForStatus(context.Background(), testStatus.ID)
if err != nil {
suite.FailNow(err.Error())
}
@@ -63,7 +63,7 @@ func (suite *StatusFaveTestSuite) TestGetStatusFaveByAccountID() {
testAccount := suite.testAccounts["local_account_1"]
testStatus := suite.testStatuses["admin_account_status_1"]
- fave, err := suite.db.GetStatusFaveByAccountID(context.Background(), testAccount.ID, testStatus.ID)
+ fave, err := suite.db.GetStatusFave(context.Background(), testAccount.ID, testStatus.ID)
suite.NoError(err)
suite.NotNil(fave)
}
@@ -129,17 +129,17 @@ func (suite *StatusFaveTestSuite) TestDeleteStatusFave() {
testFave := suite.testFaves["local_account_1_admin_account_status_1"]
ctx := context.Background()
- if err := suite.db.DeleteStatusFave(ctx, testFave.ID); err != nil {
+ if err := suite.db.DeleteStatusFaveByID(ctx, testFave.ID); err != nil {
suite.FailNow(err.Error())
}
- fave, err := suite.db.GetStatusFave(ctx, testFave.ID)
+ fave, err := suite.db.GetStatusFaveByID(ctx, testFave.ID)
suite.ErrorIs(err, db.ErrNoEntries)
suite.Nil(fave)
}
func (suite *StatusFaveTestSuite) TestDeleteStatusFaveNonExisting() {
- err := suite.db.DeleteStatusFave(context.Background(), "01GVAV715K6Y2SG9ZKS9ZA8G7G")
+ err := suite.db.DeleteStatusFaveByID(context.Background(), "01GVAV715K6Y2SG9ZKS9ZA8G7G")
suite.NoError(err)
}
diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go
index ea4a87d03..1ab140103 100644
--- a/internal/db/bundb/timeline.go
+++ b/internal/db/bundb/timeline.go
@@ -61,9 +61,12 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
Order("status.id DESC")
if maxID == "" {
+ const future = 24 * time.Hour
+
var err error
- // don't return statuses more than five minutes in the future
- maxID, err = id.NewULIDFromTime(time.Now().Add(5 * time.Minute))
+
+ // don't return statuses more than 24hr in the future
+ maxID, err = id.NewULIDFromTime(time.Now().Add(future))
if err != nil {
return nil, err
}
@@ -138,15 +141,16 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceI
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Column("status.id").
Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic).
- WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_id")).
- WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")).
WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")).
Order("status.id DESC")
if maxID == "" {
+ const future = 24 * time.Hour
+
var err error
- // don't return statuses more than five minutes in the future
- maxID, err = id.NewULIDFromTime(time.Now().Add(5 * time.Minute))
+
+ // don't return statuses more than 24hr in the future
+ maxID, err = id.NewULIDFromTime(time.Now().Add(future))
if err != nil {
return nil, err
}
diff --git a/internal/db/bundb/timeline_test.go b/internal/db/bundb/timeline_test.go
index 5a447111c..d6632b38c 100644
--- a/internal/db/bundb/timeline_test.go
+++ b/internal/db/bundb/timeline_test.go
@@ -34,15 +34,32 @@ type TimelineTestSuite struct {
}
func (suite *TimelineTestSuite) TestGetPublicTimeline() {
- ctx := context.Background()
+ var count int
+
+ for _, status := range suite.testStatuses {
+ if status.Visibility == gtsmodel.VisibilityPublic &&
+ status.BoostOfID == "" {
+ count++
+ }
+ }
+ ctx := context.Background()
s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)
suite.NoError(err)
- suite.Len(s, 6)
+ suite.Len(s, count)
}
func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() {
+ var count int
+
+ for _, status := range suite.testStatuses {
+ if status.Visibility == gtsmodel.VisibilityPublic &&
+ status.BoostOfID == "" {
+ count++
+ }
+ }
+
ctx := context.Background()
futureStatus := getFutureStatus()
@@ -53,7 +70,7 @@ func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() {
suite.NoError(err)
suite.NotContains(s, futureStatus)
- suite.Len(s, 6)
+ suite.Len(s, count)
}
func (suite *TimelineTestSuite) TestGetHomeTimeline() {
diff --git a/internal/db/media.go b/internal/db/media.go
index d86f9fe84..05609ba52 100644
--- a/internal/db/media.go
+++ b/internal/db/media.go
@@ -29,6 +29,9 @@ type Media interface {
// GetAttachmentByID gets a single attachment by its ID.
GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, Error)
+ // GetAttachmentsByIDs fetches a list of media attachments for given IDs.
+ GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error)
+
// PutAttachment inserts the given attachment into the database.
PutAttachment(ctx context.Context, media *gtsmodel.MediaAttachment) error
diff --git a/internal/db/mention.go b/internal/db/mention.go
index d66394a5d..348f946a2 100644
--- a/internal/db/mention.go
+++ b/internal/db/mention.go
@@ -30,4 +30,10 @@ type Mention interface {
// GetMentions gets multiple mentions.
GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, Error)
+
+ // PutMention will insert the given mention into the database.
+ PutMention(ctx context.Context, mention *gtsmodel.Mention) error
+
+ // DeleteMentionByID will delete mention with given ID from the database.
+ DeleteMentionByID(ctx context.Context, id string) error
}
diff --git a/internal/db/notification.go b/internal/db/notification.go
index 18e40b4c1..fd3affe90 100644
--- a/internal/db/notification.go
+++ b/internal/db/notification.go
@@ -28,14 +28,17 @@ type Notification interface {
// GetNotifications returns a slice of notifications that pertain to the given accountID.
//
// Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest).
- GetNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
+ GetAccountNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
// GetNotification returns one notification according to its id.
- GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, Error)
+ GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, Error)
- // DeleteNotification deletes one notification according to its id,
+ // PutNotification will insert the given notification into the database.
+ PutNotification(ctx context.Context, notif *gtsmodel.Notification) error
+
+ // DeleteNotificationByID deletes one notification according to its id,
// and removes that notification from the in-memory cache.
- DeleteNotification(ctx context.Context, id string) Error
+ DeleteNotificationByID(ctx context.Context, id string) Error
// DeleteNotifications mass deletes notifications targeting targetAccountID
// and/or originating from originAccountID.
@@ -50,7 +53,7 @@ type Notification interface {
// originate from originAccountID will be deleted.
//
// At least one parameter must not be an empty string.
- DeleteNotifications(ctx context.Context, targetAccountID string, originAccountID string) Error
+ DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) Error
// DeleteNotificationsForStatus deletes all notifications that relate to
// the given statusID. This function is useful when a status has been deleted,
diff --git a/internal/db/relationship.go b/internal/db/relationship.go
index d13a73dea..838647154 100644
--- a/internal/db/relationship.go
+++ b/internal/db/relationship.go
@@ -25,42 +25,86 @@ import (
// Relationship contains functions for getting or modifying the relationship between two accounts.
type Relationship interface {
- // IsBlocked checks whether account 1 has a block in place against account2.
- // If eitherDirection is true, then the function returns true if account1 blocks account2, OR if account2 blocks account1.
- IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, Error)
+ // IsBlocked checks whether source account has a block in place against target.
+ IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
+
+ // IsEitherBlocked checks whether there is a block in place between either of account1 and account2.
+ IsEitherBlocked(ctx context.Context, accountID1 string, accountID2 string) (bool, error)
+
+ // GetBlockByID fetches block with given ID from the database.
+ GetBlockByID(ctx context.Context, id string) (*gtsmodel.Block, error)
+
+ // GetBlockByURI fetches block with given AP URI from the database.
+ GetBlockByURI(ctx context.Context, uri string) (*gtsmodel.Block, error)
// GetBlock returns the block from account1 targeting account2, if it exists, or an error if it doesn't.
- //
- // Because this is slower than Blocked, only use it if you need the actual Block struct for some reason,
- // not if you're just checking for the existence of a block.
- GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, Error)
+ GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, error)
// PutBlock attempts to place the given account block in the database.
- PutBlock(ctx context.Context, block *gtsmodel.Block) Error
+ PutBlock(ctx context.Context, block *gtsmodel.Block) error
// DeleteBlockByID removes block with given ID from the database.
- DeleteBlockByID(ctx context.Context, id string) Error
+ DeleteBlockByID(ctx context.Context, id string) error
// DeleteBlockByURI removes block with given AP URI from the database.
- DeleteBlockByURI(ctx context.Context, uri string) Error
-
- // DeleteBlocksByOriginAccountID removes any blocks with accountID equal to originAccountID.
- DeleteBlocksByOriginAccountID(ctx context.Context, originAccountID string) Error
+ DeleteBlockByURI(ctx context.Context, uri string) error
- // DeleteBlocksByTargetAccountID removes any blocks with given targetAccountID.
- DeleteBlocksByTargetAccountID(ctx context.Context, targetAccountID string) Error
+ // DeleteAccountBlocks will delete all database blocks to / from the given account ID.
+ DeleteAccountBlocks(ctx context.Context, accountID string) error
// GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error)
+ // GetFollowByID fetches follow with given ID from the database.
+ GetFollowByID(ctx context.Context, id string) (*gtsmodel.Follow, error)
+
+ // GetFollowByURI fetches follow with given AP URI from the database.
+ GetFollowByURI(ctx context.Context, uri string) (*gtsmodel.Follow, error)
+
+ // GetFollow retrieves a follow if it exists between source and target accounts.
+ GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error)
+
+ // GetFollowRequestByID fetches follow request with given ID from the database.
+ GetFollowRequestByID(ctx context.Context, id string) (*gtsmodel.FollowRequest, error)
+
+ // GetFollowRequestByURI fetches follow request with given AP URI from the database.
+ GetFollowRequestByURI(ctx context.Context, uri string) (*gtsmodel.FollowRequest, error)
+
+ // GetFollowRequest retrieves a follow request if it exists between source and target accounts.
+ GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error)
+
// IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
- IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
+ IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
+
+ // IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
+ IsMutualFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
// IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
- IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
+ IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
- // IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
- IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error)
+ // PutFollow attempts to place the given account follow in the database.
+ PutFollow(ctx context.Context, follow *gtsmodel.Follow) error
+
+ // PutFollowRequest attempts to place the given account follow request in the database.
+ PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error
+
+ // DeleteFollowByID deletes a follow from the database with the given ID.
+ DeleteFollowByID(ctx context.Context, id string) error
+
+ // DeleteFollowByURI deletes a follow from the database with the given URI.
+ DeleteFollowByURI(ctx context.Context, uri string) error
+
+ // DeleteFollowRequestByID deletes a follow request from the database with the given ID.
+ DeleteFollowRequestByID(ctx context.Context, id string) error
+
+ // DeleteFollowRequestByURI deletes a follow request from the database with the given URI.
+ DeleteFollowRequestByURI(ctx context.Context, uri string) error
+
+ // DeleteAccountFollows will delete all database follows to / from the given account ID.
+ DeleteAccountFollows(ctx context.Context, accountID string) error
+
+ // DeleteAccountFollowRequests will delete all database follow requests to / from the given account ID.
+ DeleteAccountFollowRequests(ctx context.Context, accountID string) error
// AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table.
// In other words, it should create the follow, and delete the existing follow request.
@@ -69,65 +113,41 @@ type Relationship interface {
AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error)
// RejectFollowRequest fetches a follow request from the database, and then deletes it.
- //
- // The deleted follow request will be returned so that further processing can be done on it.
- RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, Error)
+ RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) Error
- // GetFollows returns a slice of follows owned by the given accountID, and/or
- // targeting the given account id.
- //
- // If accountID is set and targetAccountID isn't, then all follows created by
- // accountID will be returned.
- //
- // If targetAccountID is set and accountID isn't, then all follows targeting
- // targetAccountID will be returned.
- //
- // If both accountID and targetAccountID are set, then only 0 or 1 follows will
- // be in the returned slice.
- GetFollows(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.Follow, Error)
+ // GetAccountFollows returns a slice of follows owned by the given accountID.
+ GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
- // GetLocalFollowersIDs returns a list of local account IDs which follow the
- // targetAccountID. The returned IDs are not guaranteed to be ordered in any
- // particular way, so take care.
- GetLocalFollowersIDs(ctx context.Context, targetAccountID string) ([]string, Error)
+ // GetAccountLocalFollows returns a slice of follows owned by the given accountID, only including follows from this instance.
+ GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
- // CountFollows is like GetFollows, but just counts rather than returning.
- CountFollows(ctx context.Context, accountID string, targetAccountID string) (int, Error)
+ // CountAccountFollows returns the amount of accounts that the given accountID is following.
+ CountAccountFollows(ctx context.Context, accountID string) (int, error)
- // GetFollowRequests returns a slice of follows requests owned by the given
- // accountID, and/or targeting the given account id.
- //
- // If accountID is set and targetAccountID isn't, then all requests created by
- // accountID will be returned.
- //
- // If targetAccountID is set and accountID isn't, then all requests targeting
- // targetAccountID will be returned.
- //
- // If both accountID and targetAccountID are set, then only 0 or 1 requests will
- // be in the returned slice.
- GetFollowRequests(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.FollowRequest, Error)
+ // CountAccountLocalFollows returns the amount of accounts that the given accountID is following, only including follows from this instance.
+ CountAccountLocalFollows(ctx context.Context, accountID string) (int, error)
- // CountFollowRequests is like GetFollowRequests, but just counts rather than returning.
- CountFollowRequests(ctx context.Context, accountID string, targetAccountID string) (int, Error)
+ // GetAccountFollowers fetches follows that target given accountID.
+ GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
- // Unfollow removes a follow targeting targetAccountID and originating
- // from originAccountID.
- //
- // If a follow was removed this way, the AP URI of the follow will be
- // returned to the caller, so that further processing can take place
- // if necessary.
- //
- // If no follow was removed this way, the returned string will be empty.
- Unfollow(ctx context.Context, originAccountID string, targetAccountID string) (string, Error)
+ // GetAccountLocalFollowers fetches follows that target given accountID, only including follows from this instance.
+ GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
- // UnfollowRequest removes a follow request targeting targetAccountID
- // and originating from originAccountID.
- //
- // If a follow request was removed this way, the AP URI of the follow
- // request will be returned to the caller, so that further processing
- // can take place if necessary.
- //
- // If no follow request was removed this way, the returned string will
- // be empty.
- UnfollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (string, Error)
+ // CountAccountFollowers returns the amounts that the given ID is followed by.
+ CountAccountFollowers(ctx context.Context, accountID string) (int, error)
+
+ // CountAccountLocalFollowers returns the amounts that the given ID is followed by, only including follows from this instance.
+ CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error)
+
+ // GetAccountFollowRequests returns all follow requests targeting the given account.
+ GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error)
+
+ // GetAccountFollowRequesting returns all follow requests originating from the given account.
+ GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error)
+
+ // CountAccountFollowRequests returns number of follow requests targeting the given account.
+ CountAccountFollowRequests(ctx context.Context, accountID string) (int, error)
+
+ // CountAccountFollowerRequests returns number of follow requests originating from the given account.
+ CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error)
}
diff --git a/internal/db/status.go b/internal/db/status.go
index 16728983a..fdce19094 100644
--- a/internal/db/status.go
+++ b/internal/db/status.go
@@ -37,6 +37,9 @@ type Status interface {
// GetStatusByURL returns one status from the database, with no rel fields populated, only their linking ID / URIs
GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error)
+ // PopulateStatus ensures that all sub-models of a status are populated (e.g. mentions, attachments, etc).
+ PopulateStatus(ctx context.Context, status *gtsmodel.Status) error
+
// PutStatus stores one status in the database.
PutStatus(ctx context.Context, status *gtsmodel.Status) Error
diff --git a/internal/db/statusfave.go b/internal/db/statusfave.go
index 2d55592aa..b435da514 100644
--- a/internal/db/statusfave.go
+++ b/internal/db/statusfave.go
@@ -24,22 +24,22 @@ import (
)
type StatusFave interface {
- // GetStatusFave returns one status fave with the given id.
- GetStatusFave(ctx context.Context, id string) (*gtsmodel.StatusFave, Error)
-
// GetStatusFaveByAccountID gets one status fave created by the given
// accountID, targeting the given statusID.
- GetStatusFaveByAccountID(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, Error)
+ GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, Error)
+
+ // GetStatusFave returns one status fave with the given id.
+ GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, Error)
// GetStatusFaves returns a slice of faves/likes of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
- GetStatusFaves(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, Error)
+ GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, Error)
// PutStatusFave inserts the given statusFave into the database.
PutStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) Error
// DeleteStatusFave deletes one status fave with the given id.
- DeleteStatusFave(ctx context.Context, id string) Error
+ DeleteStatusFaveByID(ctx context.Context, id string) Error
// DeleteStatusFaves mass deletes status faves targeting targetAccountID
// and/or originating from originAccountID and/or faving statusID.