summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/api/client/account/sqlite-test.dbbin311296 -> 0 bytes
-rw-r--r--internal/api/client/fileserver/sqlite-test.dbbin307200 -> 0 bytes
-rw-r--r--internal/api/client/media/sqlite-test.dbbin311296 -> 0 bytes
-rw-r--r--internal/api/client/status/sqlite-test.dbbin311296 -> 0 bytes
-rw-r--r--internal/api/s2s/user/sqlite-test.dbbin311296 -> 0 bytes
-rw-r--r--internal/cache/account.go157
-rw-r--r--internal/cache/account_test.go63
-rw-r--r--internal/cache/status.go70
-rw-r--r--internal/cache/status_test.go66
-rw-r--r--internal/db/bundb/account.go86
-rw-r--r--internal/db/bundb/bundb.go25
-rw-r--r--internal/db/bundb/conn.go20
-rw-r--r--internal/db/bundb/relationship.go4
-rw-r--r--internal/db/bundb/relationship_test.go124
-rw-r--r--internal/db/bundb/sqlite-test.dbbin315392 -> 0 bytes
-rw-r--r--internal/db/bundb/status.go145
-rw-r--r--internal/db/bundb/status_test.go5
-rw-r--r--internal/db/status.go6
-rw-r--r--internal/federation/dereference.go8
-rw-r--r--internal/federation/dereferencing/account.go7
-rw-r--r--internal/federation/dereferencing/announce.go2
-rw-r--r--internal/federation/dereferencing/dereferencer.go4
-rw-r--r--internal/federation/dereferencing/sqlite-test.dbbin311296 -> 0 bytes
-rw-r--r--internal/federation/dereferencing/status.go56
-rw-r--r--internal/federation/dereferencing/status_test.go4
-rw-r--r--internal/federation/dereferencing/thread.go18
-rw-r--r--internal/federation/federator.go5
-rw-r--r--internal/federation/sqlite-test.dbbin315392 -> 0 bytes
-rw-r--r--internal/oauth/sqlite-test.dbbin307200 -> 0 bytes
-rw-r--r--internal/processing/fromfederator.go2
-rw-r--r--internal/processing/search.go2
-rw-r--r--internal/processing/status/sqlite-test.dbbin311296 -> 0 bytes
-rw-r--r--internal/text/sqlite-test.dbbin311296 -> 0 bytes
-rw-r--r--internal/timeline/sqlite-test.dbbin311296 -> 0 bytes
-rw-r--r--internal/typeutils/astointernal.go1
-rw-r--r--internal/typeutils/sqlite-test.dbbin311296 -> 0 bytes
36 files changed, 653 insertions, 227 deletions
diff --git a/internal/api/client/account/sqlite-test.db b/internal/api/client/account/sqlite-test.db
deleted file mode 100644
index eab8315d9..000000000
--- a/internal/api/client/account/sqlite-test.db
+++ /dev/null
Binary files differ
diff --git a/internal/api/client/fileserver/sqlite-test.db b/internal/api/client/fileserver/sqlite-test.db
deleted file mode 100644
index 5689e7edb..000000000
--- a/internal/api/client/fileserver/sqlite-test.db
+++ /dev/null
Binary files differ
diff --git a/internal/api/client/media/sqlite-test.db b/internal/api/client/media/sqlite-test.db
deleted file mode 100644
index 1ed985248..000000000
--- a/internal/api/client/media/sqlite-test.db
+++ /dev/null
Binary files differ
diff --git a/internal/api/client/status/sqlite-test.db b/internal/api/client/status/sqlite-test.db
deleted file mode 100644
index 448d10813..000000000
--- a/internal/api/client/status/sqlite-test.db
+++ /dev/null
Binary files differ
diff --git a/internal/api/s2s/user/sqlite-test.db b/internal/api/s2s/user/sqlite-test.db
deleted file mode 100644
index b67967b30..000000000
--- a/internal/api/s2s/user/sqlite-test.db
+++ /dev/null
Binary files differ
diff --git a/internal/cache/account.go b/internal/cache/account.go
new file mode 100644
index 000000000..bb402d60f
--- /dev/null
+++ b/internal/cache/account.go
@@ -0,0 +1,157 @@
+package cache
+
+import (
+ "sync"
+
+ "github.com/ReneKroon/ttlcache"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+// AccountCache is a wrapper around ttlcache.Cache to provide URL and URI lookups for gtsmodel.Account
+type AccountCache struct {
+ cache *ttlcache.Cache // map of IDs -> cached accounts
+ urls map[string]string // map of account URLs -> IDs
+ uris map[string]string // map of account URIs -> IDs
+ mutex sync.Mutex
+}
+
+// NewAccountCache returns a new instantiated AccountCache object
+func NewAccountCache() *AccountCache {
+ c := AccountCache{
+ cache: ttlcache.NewCache(),
+ urls: make(map[string]string, 100),
+ uris: make(map[string]string, 100),
+ mutex: sync.Mutex{},
+ }
+
+ // Set callback to purge lookup maps on expiration
+ c.cache.SetExpirationCallback(func(key string, value interface{}) {
+ account := value.(*gtsmodel.Account)
+
+ c.mutex.Lock()
+ delete(c.urls, account.URL)
+ delete(c.uris, account.URI)
+ c.mutex.Unlock()
+ })
+
+ return &c
+}
+
+// GetByID attempts to fetch a account from the cache by its ID, you will receive a copy for thread-safety
+func (c *AccountCache) GetByID(id string) (*gtsmodel.Account, bool) {
+ c.mutex.Lock()
+ account, ok := c.getByID(id)
+ c.mutex.Unlock()
+ return account, ok
+}
+
+// GetByURL attempts to fetch a account from the cache by its URL, you will receive a copy for thread-safety
+func (c *AccountCache) GetByURL(url string) (*gtsmodel.Account, bool) {
+ // Perform safe ID lookup
+ c.mutex.Lock()
+ id, ok := c.urls[url]
+
+ // Not found, unlock early
+ if !ok {
+ c.mutex.Unlock()
+ return nil, false
+ }
+
+ // Attempt account lookup
+ account, ok := c.getByID(id)
+ c.mutex.Unlock()
+ return account, ok
+}
+
+// GetByURI attempts to fetch a account from the cache by its URI, you will receive a copy for thread-safety
+func (c *AccountCache) GetByURI(uri string) (*gtsmodel.Account, bool) {
+ // Perform safe ID lookup
+ c.mutex.Lock()
+ id, ok := c.uris[uri]
+
+ // Not found, unlock early
+ if !ok {
+ c.mutex.Unlock()
+ return nil, false
+ }
+
+ // Attempt account lookup
+ account, ok := c.getByID(id)
+ c.mutex.Unlock()
+ return account, ok
+}
+
+// getByID performs an unsafe (no mutex locks) lookup of account by ID, returning a copy of account in cache
+func (c *AccountCache) getByID(id string) (*gtsmodel.Account, bool) {
+ v, ok := c.cache.Get(id)
+ if !ok {
+ return nil, false
+ }
+ return copyAccount(v.(*gtsmodel.Account)), true
+}
+
+// Put places a account in the cache, ensuring that the object place is a copy for thread-safety
+func (c *AccountCache) Put(account *gtsmodel.Account) {
+ if account == nil || account.ID == "" {
+ panic("invalid account")
+ }
+
+ c.mutex.Lock()
+ c.cache.Set(account.ID, copyAccount(account))
+ if account.URL != "" {
+ c.urls[account.URL] = account.ID
+ }
+ if account.URI != "" {
+ c.uris[account.URI] = account.ID
+ }
+ c.mutex.Unlock()
+}
+
+// copyAccount performs a surface-level copy of account, only keeping attached IDs intact, not the objects.
+// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr)
+// this should be a relatively cheap process
+func copyAccount(account *gtsmodel.Account) *gtsmodel.Account {
+ return &gtsmodel.Account{
+ ID: account.ID,
+ Username: account.Username,
+ Domain: account.Domain,
+ AvatarMediaAttachmentID: account.AvatarMediaAttachmentID,
+ AvatarMediaAttachment: nil,
+ AvatarRemoteURL: account.AvatarRemoteURL,
+ HeaderMediaAttachmentID: account.HeaderMediaAttachmentID,
+ HeaderMediaAttachment: nil,
+ HeaderRemoteURL: account.HeaderRemoteURL,
+ DisplayName: account.DisplayName,
+ Fields: account.Fields,
+ Note: account.Note,
+ Memorial: account.Memorial,
+ MovedToAccountID: account.MovedToAccountID,
+ CreatedAt: account.CreatedAt,
+ UpdatedAt: account.UpdatedAt,
+ Bot: account.Bot,
+ Reason: account.Reason,
+ Locked: account.Locked,
+ Discoverable: account.Discoverable,
+ Privacy: account.Privacy,
+ Sensitive: account.Sensitive,
+ Language: account.Language,
+ URI: account.URI,
+ URL: account.URL,
+ LastWebfingeredAt: account.LastWebfingeredAt,
+ InboxURI: account.InboxURI,
+ OutboxURI: account.OutboxURI,
+ FollowingURI: account.FollowingURI,
+ FollowersURI: account.FollowersURI,
+ FeaturedCollectionURI: account.FeaturedCollectionURI,
+ ActorType: account.ActorType,
+ AlsoKnownAs: account.AlsoKnownAs,
+ PrivateKey: account.PrivateKey,
+ PublicKey: account.PublicKey,
+ PublicKeyURI: account.PublicKeyURI,
+ SensitizedAt: account.SensitizedAt,
+ SilencedAt: account.SilencedAt,
+ SuspendedAt: account.SuspendedAt,
+ HideCollections: account.HideCollections,
+ SuspensionOrigin: account.SuspensionOrigin,
+ }
+}
diff --git a/internal/cache/account_test.go b/internal/cache/account_test.go
new file mode 100644
index 000000000..f84ad2261
--- /dev/null
+++ b/internal/cache/account_test.go
@@ -0,0 +1,63 @@
+package cache_test
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type AccountCacheTestSuite struct {
+ suite.Suite
+ data map[string]*gtsmodel.Account
+ cache *cache.AccountCache
+}
+
+func (suite *AccountCacheTestSuite) SetupSuite() {
+ suite.data = testrig.NewTestAccounts()
+}
+
+func (suite *AccountCacheTestSuite) SetupTest() {
+ suite.cache = cache.NewAccountCache()
+}
+
+func (suite *AccountCacheTestSuite) TearDownTest() {
+ suite.data = nil
+ suite.cache = nil
+}
+
+func (suite *AccountCacheTestSuite) TestAccountCache() {
+ for _, account := range suite.data {
+ // Place in the cache
+ suite.cache.Put(account)
+ }
+
+ for _, account := range suite.data {
+ var ok bool
+ var check *gtsmodel.Account
+
+ // Check we can retrieve
+ check, ok = suite.cache.GetByID(account.ID)
+ if !ok && !accountIs(account, check) {
+ suite.Fail("Failed to fetch expected account with ID: %s", account.ID)
+ }
+ check, ok = suite.cache.GetByURI(account.URI)
+ if account.URI != "" && !ok && !accountIs(account, check) {
+ suite.Fail("Failed to fetch expected account with URI: %s", account.URI)
+ }
+ check, ok = suite.cache.GetByURL(account.URL)
+ if account.URL != "" && !ok && !accountIs(account, check) {
+ suite.Fail("Failed to fetch expected account with URL: %s", account.URL)
+ }
+ }
+}
+
+func TestAccountCache(t *testing.T) {
+ suite.Run(t, &AccountCacheTestSuite{})
+}
+
+func accountIs(account1, account2 *gtsmodel.Account) bool {
+ return account1.ID == account2.ID && account1.URI == account2.URI && account1.URL == account2.URL
+}
diff --git a/internal/cache/status.go b/internal/cache/status.go
index 895a5692c..028abc8f7 100644
--- a/internal/cache/status.go
+++ b/internal/cache/status.go
@@ -37,7 +37,7 @@ func NewStatusCache() *StatusCache {
return &c
}
-// GetByID attempts to fetch a status from the cache by its ID
+// GetByID attempts to fetch a status from the cache by its ID, you will receive a copy for thread-safety
func (c *StatusCache) GetByID(id string) (*gtsmodel.Status, bool) {
c.mutex.Lock()
status, ok := c.getByID(id)
@@ -45,7 +45,7 @@ func (c *StatusCache) GetByID(id string) (*gtsmodel.Status, bool) {
return status, ok
}
-// GetByURL attempts to fetch a status from the cache by its URL
+// GetByURL attempts to fetch a status from the cache by its URL, you will receive a copy for thread-safety
func (c *StatusCache) GetByURL(url string) (*gtsmodel.Status, bool) {
// Perform safe ID lookup
c.mutex.Lock()
@@ -63,7 +63,7 @@ func (c *StatusCache) GetByURL(url string) (*gtsmodel.Status, bool) {
return status, ok
}
-// GetByURI attempts to fetch a status from the cache by its URI
+// GetByURI attempts to fetch a status from the cache by its URI, you will receive a copy for thread-safety
func (c *StatusCache) GetByURI(uri string) (*gtsmodel.Status, bool) {
// Perform safe ID lookup
c.mutex.Lock()
@@ -81,26 +81,72 @@ func (c *StatusCache) GetByURI(uri string) (*gtsmodel.Status, bool) {
return status, ok
}
-// getByID performs an unsafe (no mutex locks) lookup of status by ID
+// getByID performs an unsafe (no mutex locks) lookup of status by ID, returning a copy of status in cache
func (c *StatusCache) getByID(id string) (*gtsmodel.Status, bool) {
v, ok := c.cache.Get(id)
if !ok {
return nil, false
}
- return v.(*gtsmodel.Status), true
+ return copyStatus(v.(*gtsmodel.Status)), true
}
-// Put places a status in the cache
+// Put places a status in the cache, ensuring that the object place is a copy for thread-safety
func (c *StatusCache) Put(status *gtsmodel.Status) {
- if status == nil || status.ID == "" ||
- status.URL == "" ||
- status.URI == "" {
+ if status == nil || status.ID == "" {
panic("invalid status")
}
c.mutex.Lock()
- c.cache.Set(status.ID, status)
- c.urls[status.URL] = status.ID
- c.uris[status.URI] = status.ID
+ c.cache.Set(status.ID, copyStatus(status))
+ if status.URL != "" {
+ c.urls[status.URL] = status.ID
+ }
+ if status.URI != "" {
+ c.uris[status.URI] = status.ID
+ }
c.mutex.Unlock()
}
+
+// copyStatus performs a surface-level copy of status, only keeping attached IDs intact, not the objects.
+// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr)
+// this should be a relatively cheap process
+func copyStatus(status *gtsmodel.Status) *gtsmodel.Status {
+ return &gtsmodel.Status{
+ ID: status.ID,
+ URI: status.URI,
+ URL: status.URL,
+ Content: status.Content,
+ AttachmentIDs: status.AttachmentIDs,
+ Attachments: nil,
+ TagIDs: status.TagIDs,
+ Tags: nil,
+ MentionIDs: status.MentionIDs,
+ Mentions: nil,
+ EmojiIDs: status.EmojiIDs,
+ Emojis: nil,
+ CreatedAt: status.CreatedAt,
+ UpdatedAt: status.UpdatedAt,
+ Local: status.Local,
+ AccountID: status.AccountID,
+ Account: nil,
+ AccountURI: status.AccountURI,
+ InReplyToID: status.InReplyToID,
+ InReplyTo: nil,
+ InReplyToURI: status.InReplyToURI,
+ InReplyToAccountID: status.InReplyToAccountID,
+ InReplyToAccount: nil,
+ BoostOfID: status.BoostOfID,
+ BoostOf: nil,
+ BoostOfAccountID: status.BoostOfAccountID,
+ BoostOfAccount: nil,
+ ContentWarning: status.ContentWarning,
+ Visibility: status.Visibility,
+ Sensitive: status.Sensitive,
+ Language: status.Language,
+ CreatedWithApplicationID: status.CreatedWithApplicationID,
+ VisibilityAdvanced: status.VisibilityAdvanced,
+ ActivityStreamsType: status.ActivityStreamsType,
+ Text: status.Text,
+ Pinned: status.Pinned,
+ }
+}
diff --git a/internal/cache/status_test.go b/internal/cache/status_test.go
index 10dee5bca..222961025 100644
--- a/internal/cache/status_test.go
+++ b/internal/cache/status_test.go
@@ -3,39 +3,61 @@ package cache_test
import (
"testing"
+ "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/testrig"
)
-func TestStatusCache(t *testing.T) {
- cache := cache.NewStatusCache()
+type StatusCacheTestSuite struct {
+ suite.Suite
+ data map[string]*gtsmodel.Status
+ cache *cache.StatusCache
+}
- // Attempt to place a status
- status := gtsmodel.Status{
- ID: "id",
- URI: "uri",
- URL: "url",
- }
- cache.Put(&status)
+func (suite *StatusCacheTestSuite) SetupSuite() {
+ suite.data = testrig.NewTestStatuses()
+}
- var ok bool
- var check *gtsmodel.Status
+func (suite *StatusCacheTestSuite) SetupTest() {
+ suite.cache = cache.NewStatusCache()
+}
- // Check we can retrieve
- check, ok = cache.GetByID(status.ID)
- if !ok || !statusIs(&status, check) {
- t.Fatal("Could not find expected status")
- }
- check, ok = cache.GetByURI(status.URI)
- if !ok || !statusIs(&status, check) {
- t.Fatal("Could not find expected status")
+func (suite *StatusCacheTestSuite) TearDownTest() {
+ suite.data = nil
+ suite.cache = nil
+}
+
+func (suite *StatusCacheTestSuite) TestStatusCache() {
+ for _, status := range suite.data {
+ // Place in the cache
+ suite.cache.Put(status)
}
- check, ok = cache.GetByURL(status.URL)
- if !ok || !statusIs(&status, check) {
- t.Fatal("Could not find expected status")
+
+ for _, status := range suite.data {
+ var ok bool
+ var check *gtsmodel.Status
+
+ // Check we can retrieve
+ check, ok = suite.cache.GetByID(status.ID)
+ if !ok && !statusIs(status, check) {
+ suite.Fail("Failed to fetch expected account with ID: %s", status.ID)
+ }
+ check, ok = suite.cache.GetByURI(status.URI)
+ if status.URI != "" && !ok && !statusIs(status, check) {
+ suite.Fail("Failed to fetch expected account with URI: %s", status.URI)
+ }
+ check, ok = suite.cache.GetByURL(status.URL)
+ if status.URL != "" && !ok && !statusIs(status, check) {
+ suite.Fail("Failed to fetch expected account with URL: %s", status.URL)
+ }
}
}
+func TestStatusCache(t *testing.T) {
+ suite.Run(t, &StatusCacheTestSuite{})
+}
+
func statusIs(status1, status2 *gtsmodel.Status) bool {
return status1.ID == status2.ID && status1.URI == status2.URI && status1.URL == status2.URL
}
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index d7d45a739..32a70f7cd 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -25,6 +25,7 @@ import (
"strings"
"time"
+ "github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -34,6 +35,7 @@ import (
type accountDB struct {
config *config.Config
conn *DBConn
+ cache *cache.AccountCache
}
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
@@ -45,60 +47,80 @@ func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
}
func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
- account := new(gtsmodel.Account)
-
- q := a.newAccountQ(account).
- Where("account.id = ?", id)
-
- err := q.Scan(ctx)
- if err != nil {
- return nil, a.conn.ProcessError(err)
- }
- return account, nil
+ return a.getAccount(
+ ctx,
+ func() (*gtsmodel.Account, bool) {
+ return a.cache.GetByID(id)
+ },
+ func(account *gtsmodel.Account) error {
+ return a.newAccountQ(account).Where("account.id = ?", id).Scan(ctx)
+ },
+ )
}
func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
- account := new(gtsmodel.Account)
-
- q := a.newAccountQ(account).
- Where("account.uri = ?", uri)
+ return a.getAccount(
+ ctx,
+ func() (*gtsmodel.Account, bool) {
+ return a.cache.GetByURI(uri)
+ },
+ func(account *gtsmodel.Account) error {
+ return a.newAccountQ(account).Where("account.uri = ?", uri).Scan(ctx)
+ },
+ )
+}
- err := q.Scan(ctx)
- if err != nil {
- return nil, a.conn.ProcessError(err)
- }
- return account, nil
+func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, db.Error) {
+ return a.getAccount(
+ ctx,
+ func() (*gtsmodel.Account, bool) {
+ return a.cache.GetByURL(url)
+ },
+ func(account *gtsmodel.Account) error {
+ return a.newAccountQ(account).Where("account.url = ?", url).Scan(ctx)
+ },
+ )
}
-func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
- account := new(gtsmodel.Account)
+func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) {
+ // Attempt to fetch cached account
+ account, cached := cacheGet()
- q := a.newAccountQ(account).
- Where("account.url = ?", uri)
+ if !cached {
+ account = &gtsmodel.Account{}
- err := q.Scan(ctx)
- if err != nil {
- return nil, a.conn.ProcessError(err)
+ // Not cached! Perform database query
+ err := dbQuery(account)
+ if err != nil {
+ return nil, a.conn.ProcessError(err)
+ }
+
+ // Place in the cache
+ a.cache.Put(account)
}
+
return account, nil
}
func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
if strings.TrimSpace(account.ID) == "" {
+ // TODO: we should not need this check here
return nil, errors.New("account had no ID")
}
+ // Update the account's last-used
account.UpdatedAt = time.Now()
- q := a.conn.
- NewUpdate().
- Model(account).
- WherePK()
-
- _, err := q.Exec(ctx)
+ // Update the account model in the DB
+ _, err := a.conn.NewUpdate().Model(account).WherePK().Exec(ctx)
if err != nil {
return nil, a.conn.ProcessError(err)
}
+
+ // Place updated account in cache
+ // (this will replace existing, i.e. invalidating)
+ a.cache.Put(account)
+
return account, nil
}
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index 248232fe3..6fcc56e51 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -91,6 +91,15 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)
conn = WrapDBConn(bun.NewDB(sqldb, pgdialect.New()), log)
case dbTypeSqlite:
// SQLITE
+
+ // Drop anything fancy from DB address
+ c.DBConfig.Address = strings.Split(c.DBConfig.Address, "?")[0]
+ c.DBConfig.Address = strings.TrimPrefix(c.DBConfig.Address, "file:")
+
+ // Append our own SQLite preferences
+ c.DBConfig.Address = "file:" + c.DBConfig.Address + "?cache=shared"
+
+ // Open new DB instance
var err error
sqldb, err = sql.Open("sqlite", c.DBConfig.Address)
if err != nil {
@@ -98,7 +107,7 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)
}
conn = WrapDBConn(bun.NewDB(sqldb, sqlitedialect.New()), log)
- if strings.HasPrefix(strings.TrimPrefix(c.DBConfig.Address, "file:"), ":memory:") {
+ if c.DBConfig.Address == "file::memory:?cache=shared" {
log.Warn("sqlite in-memory database should only be used for debugging")
// don't close connections on disconnect -- otherwise
@@ -121,11 +130,10 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)
conn.RegisterModel(t)
}
+ accounts := &accountDB{config: c, conn: conn, cache: cache.NewAccountCache()}
+
ps := &bunDBService{
- Account: &accountDB{
- config: c,
- conn: conn,
- },
+ Account: accounts,
Admin: &adminDB{
config: c,
conn: conn,
@@ -165,9 +173,10 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)
conn: conn,
},
Status: &statusDB{
- config: c,
- conn: conn,
- cache: cache.NewStatusCache(),
+ config: c,
+ conn: conn,
+ cache: cache.NewStatusCache(),
+ accounts: accounts,
},
Timeline: &timelineDB{
config: c,
diff --git a/internal/db/bundb/conn.go b/internal/db/bundb/conn.go
index 698adff3d..abaebcebd 100644
--- a/internal/db/bundb/conn.go
+++ b/internal/db/bundb/conn.go
@@ -12,6 +12,8 @@ import (
// dbConn wrapps a bun.DB conn to provide SQL-type specific additional functionality
type DBConn struct {
+ // TODO: move *Config here, no need to be in each struct type
+
errProc func(error) db.Error // errProc is the SQL-type specific error processor
log *logrus.Logger // log is the logger passed with this DBConn
*bun.DB // DB is the underlying bun.DB connection
@@ -35,6 +37,24 @@ func WrapDBConn(dbConn *bun.DB, log *logrus.Logger) *DBConn {
}
}
+func (conn *DBConn) RunInTx(ctx context.Context, fn func(bun.Tx) error) db.Error {
+ // Acquire a new transaction
+ tx, err := conn.BeginTx(ctx, nil)
+ if err != nil {
+ return conn.ProcessError(err)
+ }
+
+ // Perform supplied transaction
+ if err = fn(tx); err != nil {
+ tx.Rollback() //nolint
+ return conn.ProcessError(err)
+ }
+
+ // Finally, commit transaction
+ err = tx.Commit()
+ return conn.ProcessError(err)
+}
+
// ProcessError processes an error to replace any known values with our own db.Error types,
// making it easier to catch specific situations (e.g. no rows, already exists, etc)
func (conn *DBConn) ProcessError(err error) db.Error {
diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go
index 56b752593..64d896527 100644
--- a/internal/db/bundb/relationship.go
+++ b/internal/db/bundb/relationship.go
@@ -237,7 +237,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
if _, err := r.conn.
NewInsert().
Model(follow).
- On("CONFLICT ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).
+ On("CONFLICT (account_id,target_account_id) DO UPDATE set uri = ?", follow.URI).
Exec(ctx); err != nil {
return nil, r.conn.ProcessError(err)
}
@@ -298,7 +298,7 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
if localOnly {
q = q.ColumnExpr("follow.*").
- Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)").
+ Join("JOIN accounts AS a ON follow.account_id = CAST(a.id as TEXT)").
Where("follow.target_account_id = ?", accountID).
WhereGroup(" AND ", whereEmptyOrNull("a.domain"))
} else {
diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go
new file mode 100644
index 000000000..dcc71b37c
--- /dev/null
+++ b/internal/db/bundb/relationship_test.go
@@ -0,0 +1,124 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb_test
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type RelationshipTestSuite struct {
+ BunDBStandardTestSuite
+}
+
+func (suite *RelationshipTestSuite) SetupSuite() {
+ suite.testTokens = testrig.NewTestTokens()
+ suite.testClients = testrig.NewTestClients()
+ suite.testApplications = testrig.NewTestApplications()
+ suite.testUsers = testrig.NewTestUsers()
+ suite.testAccounts = testrig.NewTestAccounts()
+ suite.testAttachments = testrig.NewTestAttachments()
+ suite.testStatuses = testrig.NewTestStatuses()
+ suite.testTags = testrig.NewTestTags()
+ suite.testMentions = testrig.NewTestMentions()
+}
+
+func (suite *RelationshipTestSuite) SetupTest() {
+ suite.config = testrig.NewTestConfig()
+ suite.db = testrig.NewTestDB()
+ suite.log = testrig.NewTestLog()
+
+ testrig.StandardDBSetup(suite.db, suite.testAccounts)
+}
+
+func (suite *RelationshipTestSuite) TearDownTest() {
+ testrig.StandardDBTeardown(suite.db)
+}
+
+func (suite *RelationshipTestSuite) TestIsBlocked() {
+ suite.Suite.T().Skip("TODO: implement")
+}
+
+func (suite *RelationshipTestSuite) TestGetBlock() {
+ suite.Suite.T().Skip("TODO: implement")
+}
+
+func (suite *RelationshipTestSuite) TestGetRelationship() {
+ suite.Suite.T().Skip("TODO: implement")
+}
+
+func (suite *RelationshipTestSuite) TestIsFollowing() {
+ suite.Suite.T().Skip("TODO: implement")
+}
+
+func (suite *RelationshipTestSuite) TestIsMutualFollowing() {
+ suite.Suite.T().Skip("TODO: implement")
+}
+
+func (suite *RelationshipTestSuite) AcceptFollowRequest() {
+ for _, account := range suite.testAccounts {
+ _, err := suite.db.AcceptFollowRequest(context.Background(), account.ID, "NON-EXISTENT-ID")
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ suite.Suite.Fail("error accepting follow request: %v", err)
+ }
+ }
+}
+
+func (suite *RelationshipTestSuite) GetAccountFollowRequests() {
+ suite.Suite.T().Skip("TODO: implement")
+}
+
+func (suite *RelationshipTestSuite) GetAccountFollows() {
+ suite.Suite.T().Skip("TODO: implement")
+}
+
+func (suite *RelationshipTestSuite) CountAccountFollows() {
+ suite.Suite.T().Skip("TODO: implement")
+}
+
+func (suite *RelationshipTestSuite) GetAccountFollowedBy() {
+ // TODO: more comprehensive tests here
+
+ for _, account := range suite.testAccounts {
+ var err error
+
+ _, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, false)
+ if err != nil {
+ suite.Suite.Fail("error checking accounts followed by: %v", err)
+ }
+
+ _, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, true)
+ if err != nil {
+ suite.Suite.Fail("error checking localOnly accounts followed by: %v", err)
+ }
+ }
+}
+
+func (suite *RelationshipTestSuite) CountAccountFollowedBy() {
+ suite.Suite.T().Skip("TODO: implement")
+}
+
+func TestRelationshipTestSuite(t *testing.T) {
+ suite.Run(t, new(RelationshipTestSuite))
+}
diff --git a/internal/db/bundb/sqlite-test.db b/internal/db/bundb/sqlite-test.db
deleted file mode 100644
index ed3b25ee3..000000000
--- a/internal/db/bundb/sqlite-test.db
+++ /dev/null
Binary files differ
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go
index 1d5acf0fc..9464cfadf 100644
--- a/internal/db/bundb/status.go
+++ b/internal/db/bundb/status.go
@@ -21,7 +21,6 @@ package bundb
import (
"container/list"
"context"
- "errors"
"time"
"github.com/superseriousbusiness/gotosocial/internal/cache"
@@ -35,6 +34,11 @@ type statusDB struct {
config *config.Config
conn *DBConn
cache *cache.StatusCache
+
+ // TODO: keep method definitions in same place but instead have receiver
+ // all point to one single "db" type, so they can all share methods
+ // and caches where necessary
+ accounts *accountDB
}
func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
@@ -51,30 +55,6 @@ func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
Relation("CreatedWithApplication")
}
-func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status {
- if status.InReplyToID != "" && status.InReplyTo == nil {
- // TODO: do we want to keep this possibly recursive strategy?
-
- if inReplyTo, cached := s.cache.GetByID(status.InReplyToID); cached {
- status.InReplyTo = inReplyTo
- } else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil {
- status.InReplyTo = inReplyTo
- }
- }
-
- if status.BoostOfID != "" && status.BoostOf == nil {
- // TODO: do we want to keep this possibly recursive strategy?
-
- if boostOf, cached := s.cache.GetByID(status.BoostOfID); cached {
- status.BoostOf = boostOf
- } else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil {
- status.BoostOf = boostOf
- }
- }
-
- return status
-}
-
func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
return s.conn.
NewSelect().
@@ -85,64 +65,79 @@ func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
}
func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) {
- if status, cached := s.cache.GetByID(id); cached {
- return status, nil
- }
-
- status := &gtsmodel.Status{}
-
- q := s.newStatusQ(status).
- Where("status.id = ?", id)
-
- err := q.Scan(ctx)
- if err != nil {
- return nil, s.conn.ProcessError(err)
- }
-
- s.cache.Put(status)
- return s.getAttachedStatuses(ctx, status), nil
+ return s.getStatus(
+ ctx,
+ func() (*gtsmodel.Status, bool) {
+ return s.cache.GetByID(id)
+ },
+ func(status *gtsmodel.Status) error {
+ return s.newStatusQ(status).Where("status.id = ?", id).Scan(ctx)
+ },
+ )
}
func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
- if status, cached := s.cache.GetByURI(uri); cached {
- return status, nil
- }
-
- status := &gtsmodel.Status{}
+ return s.getStatus(
+ ctx,
+ func() (*gtsmodel.Status, bool) {
+ return s.cache.GetByURI(uri)
+ },
+ func(status *gtsmodel.Status) error {
+ return s.newStatusQ(status).Where("LOWER(status.uri) = LOWER(?)", uri).Scan(ctx)
+ },
+ )
+}
- q := s.newStatusQ(status).
- Where("LOWER(status.uri) = LOWER(?)", uri)
+func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) {
+ return s.getStatus(
+ ctx,
+ func() (*gtsmodel.Status, bool) {
+ return s.cache.GetByURL(url)
+ },
+ func(status *gtsmodel.Status) error {
+ return s.newStatusQ(status).Where("LOWER(status.url) = LOWER(?)", url).Scan(ctx)
+ },
+ )
+}
- err := q.Scan(ctx)
- if err != nil {
- return nil, s.conn.ProcessError(err)
- }
+func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Status, bool), dbQuery func(*gtsmodel.Status) error) (*gtsmodel.Status, db.Error) {
+ // Attempt to fetch cached status
+ status, cached := cacheGet()
- s.cache.Put(status)
- return s.getAttachedStatuses(ctx, status), nil
-}
+ if !cached {
+ status = &gtsmodel.Status{}
-func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) {
- if status, cached := s.cache.GetByURL(url); cached {
- return status, nil
- }
+ // Not cached! Perform database query
+ err := dbQuery(status)
+ if err != nil {
+ return nil, s.conn.ProcessError(err)
+ }
- status := &gtsmodel.Status{}
+ // If there is boosted, fetch from DB also
+ if status.BoostOfID != "" {
+ boostOf, err := s.GetStatusByID(ctx, status.BoostOfID)
+ if err == nil {
+ status.BoostOf = boostOf
+ }
+ }
- q := s.newStatusQ(status).
- Where("LOWER(status.url) = LOWER(?)", url)
+ // Place in the cache
+ s.cache.Put(status)
+ }
- err := q.Scan(ctx)
+ // Set the status author account
+ author, err := s.accounts.GetAccountByID(ctx, status.AccountID)
if err != nil {
- return nil, s.conn.ProcessError(err)
+ return nil, err
}
- s.cache.Put(status)
- return s.getAttachedStatuses(ctx, status), nil
+ // Return the prepared status
+ status.Account = author
+ return status, nil
}
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
- transaction := func(ctx context.Context, tx bun.Tx) error {
+ return s.conn.RunInTx(ctx, func(tx bun.Tx) error {
// create links between this status and any emojis it uses
for _, i := range status.EmojiIDs {
if _, err := tx.NewInsert().Model(&gtsmodel.StatusToEmoji{
@@ -174,10 +169,10 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
}
}
+ // Finally, insert the status
_, err := tx.NewInsert().Model(status).Exec(ctx)
return err
- }
- return s.conn.ProcessError(s.conn.RunInTx(ctx, nil, transaction))
+ })
}
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
@@ -210,12 +205,8 @@ func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Statu
children := []*gtsmodel.Status{}
for e := foundStatuses.Front(); e != nil; e = e.Next() {
- entry, ok := e.Value.(*gtsmodel.Status)
- if !ok {
- panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
- }
-
// only append children, not the overall parent status
+ entry := e.Value.(*gtsmodel.Status)
if entry.ID != status.ID {
children = append(children, entry)
}
@@ -242,11 +233,7 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status,
for _, child := range immediateChildren {
insertLoop:
for e := foundStatuses.Front(); e != nil; e = e.Next() {
- entry, ok := e.Value.(*gtsmodel.Status)
- if !ok {
- panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
- }
-
+ entry := e.Value.(*gtsmodel.Status)
if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {
foundStatuses.InsertAfter(child, e)
break insertLoop
diff --git a/internal/db/bundb/status_test.go b/internal/db/bundb/status_test.go
index 4f846441b..7acc86ff9 100644
--- a/internal/db/bundb/status_test.go
+++ b/internal/db/bundb/status_test.go
@@ -105,10 +105,9 @@ func (suite *StatusTestSuite) TestGetStatusWithMention() {
suite.NotNil(status)
suite.NotNil(status.Account)
suite.NotNil(status.CreatedWithApplication)
- suite.NotEmpty(status.Mentions)
suite.NotEmpty(status.MentionIDs)
- suite.NotNil(status.InReplyTo)
- suite.NotNil(status.InReplyToAccount)
+ suite.NotEmpty(status.InReplyToID)
+ suite.NotEmpty(status.InReplyToAccountID)
}
func (suite *StatusTestSuite) TestGetStatusTwice() {
diff --git a/internal/db/status.go b/internal/db/status.go
index 7430433c4..f26f8942e 100644
--- a/internal/db/status.go
+++ b/internal/db/status.go
@@ -26,13 +26,13 @@ import (
// Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses.
type Status interface {
- // GetStatusByID returns one status from the database, with all rel fields populated (if possible).
+ // GetStatusByID returns one status from the database, with no rel fields populated, only their linking ID / URIs
GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, Error)
- // GetStatusByURI returns one status from the database, with all rel fields populated (if possible).
+ // GetStatusByURI returns one status from the database, with no rel fields populated, only their linking ID / URIs
GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, Error)
- // GetStatusByURL returns one status from the database, with all rel fields populated (if possible).
+ // GetStatusByURL returns one status from the database, with no rel fields populated, only their linking ID / URIs
GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error)
// PutStatus stores one status in the database.
diff --git a/internal/federation/dereference.go b/internal/federation/dereference.go
index a09f0f84b..a9dbabb42 100644
--- a/internal/federation/dereference.go
+++ b/internal/federation/dereference.go
@@ -34,12 +34,12 @@ func (f *federator) EnrichRemoteAccount(ctx context.Context, username string, ac
return f.dereferencer.EnrichRemoteAccount(ctx, username, account)
}
-func (f *federator) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) {
- return f.dereferencer.GetRemoteStatus(ctx, username, remoteStatusID, refresh)
+func (f *federator) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh, includeParent, includeChilds bool) (*gtsmodel.Status, ap.Statusable, bool, error) {
+ return f.dereferencer.GetRemoteStatus(ctx, username, remoteStatusID, refresh, includeParent, includeChilds)
}
-func (f *federator) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error) {
- return f.dereferencer.EnrichRemoteStatus(ctx, username, status)
+func (f *federator) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status, includeParent, includeChilds bool) (*gtsmodel.Status, error) {
+ return f.dereferencer.EnrichRemoteStatus(ctx, username, status, includeParent, includeChilds)
}
func (f *federator) DereferenceRemoteThread(ctx context.Context, username string, statusIRI *url.URL) error {
diff --git a/internal/federation/dereferencing/account.go b/internal/federation/dereferencing/account.go
index 2eee0645d..8cae002e8 100644
--- a/internal/federation/dereferencing/account.go
+++ b/internal/federation/dereferencing/account.go
@@ -48,7 +48,6 @@ func instanceAccount(account *gtsmodel.Account) bool {
// EnrichRemoteAccount is mostly useful for calling after an account has been initially created by
// the federatingDB's Create function, or during the federated authorization flow.
func (d *deref) EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error) {
-
// if we're dealing with an instance account, we don't need to update anything
if instanceAccount(account) {
return account, nil
@@ -58,13 +57,13 @@ func (d *deref) EnrichRemoteAccount(ctx context.Context, username string, accoun
return nil, err
}
- var err error
- account, err = d.db.UpdateAccount(ctx, account)
+ updated, err := d.db.UpdateAccount(ctx, account)
if err != nil {
d.log.Errorf("EnrichRemoteAccount: error updating account: %s", err)
+ return account, nil
}
- return account, nil
+ return updated, nil
}
// GetRemoteAccount completely dereferences a remote account, converts it to a GtS model account,
diff --git a/internal/federation/dereferencing/announce.go b/internal/federation/dereferencing/announce.go
index 33af74ebe..d5cc5ad0c 100644
--- a/internal/federation/dereferencing/announce.go
+++ b/internal/federation/dereferencing/announce.go
@@ -46,7 +46,7 @@ func (d *deref) DereferenceAnnounce(ctx context.Context, announce *gtsmodel.Stat
return fmt.Errorf("DereferenceAnnounce: error dereferencing thread of boosted status: %s", err)
}
- boostedStatus, _, _, err := d.GetRemoteStatus(ctx, requestingUsername, boostedStatusURI, false)
+ boostedStatus, _, _, err := d.GetRemoteStatus(ctx, requestingUsername, boostedStatusURI, false, false, false)
if err != nil {
return fmt.Errorf("DereferenceAnnounce: error dereferencing remote status with id %s: %s", announce.BoostOf.URI, err)
}
diff --git a/internal/federation/dereferencing/dereferencer.go b/internal/federation/dereferencing/dereferencer.go
index 4191bd283..8ad21013f 100644
--- a/internal/federation/dereferencing/dereferencer.go
+++ b/internal/federation/dereferencing/dereferencer.go
@@ -38,8 +38,8 @@ type Dereferencer interface {
GetRemoteAccount(ctx context.Context, username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error)
EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error)
- GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error)
- EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error)
+ GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh, includeParent, includeChilds bool) (*gtsmodel.Status, ap.Statusable, bool, error)
+ EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status, includeParent, includeChilds bool) (*gtsmodel.Status, error)
GetRemoteInstance(ctx context.Context, username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error)
diff --git a/internal/federation/dereferencing/sqlite-test.db b/internal/federation/dereferencing/sqlite-test.db
deleted file mode 100644
index bef45b3af..000000000
--- a/internal/federation/dereferencing/sqlite-test.db
+++ /dev/null
Binary files differ
diff --git a/internal/federation/dereferencing/status.go b/internal/federation/dereferencing/status.go
index 3fa1e4133..7a7f928f1 100644
--- a/internal/federation/dereferencing/status.go
+++ b/internal/federation/dereferencing/status.go
@@ -39,8 +39,8 @@ import (
//
// EnrichRemoteStatus is mostly useful for calling after a status has been initially created by
// the federatingDB's Create function, but additional dereferencing is needed on it.
-func (d *deref) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error) {
- if err := d.populateStatusFields(ctx, status, username); err != nil {
+func (d *deref) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status, includeParent, includeChilds bool) (*gtsmodel.Status, error) {
+ if err := d.populateStatusFields(ctx, status, username, includeParent, includeChilds); err != nil {
return nil, err
}
@@ -62,7 +62,7 @@ func (d *deref) EnrichRemoteStatus(ctx context.Context, username string, status
// If a dereference was performed, then the function also returns the ap.Statusable representation for further processing.
//
// SIDE EFFECTS: remote status will be stored in the database, and the remote status owner will also be stored.
-func (d *deref) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) {
+func (d *deref) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh, includeParent, includeChilds bool) (*gtsmodel.Status, ap.Statusable, bool, error) {
new := true
// check if we already have the status in our db
@@ -105,7 +105,7 @@ func (d *deref) GetRemoteStatus(ctx context.Context, username string, remoteStat
}
gtsStatus.ID = ulid
- if err := d.populateStatusFields(ctx, gtsStatus, username); err != nil {
+ if err := d.populateStatusFields(ctx, gtsStatus, username, includeParent, includeChilds); err != nil {
return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error populating status fields: %s", err)
}
@@ -115,7 +115,7 @@ func (d *deref) GetRemoteStatus(ctx context.Context, username string, remoteStat
} else {
gtsStatus.ID = maybeStatus.ID
- if err := d.populateStatusFields(ctx, gtsStatus, username); err != nil {
+ if err := d.populateStatusFields(ctx, gtsStatus, username, includeParent, includeChilds); err != nil {
return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error populating status fields: %s", err)
}
@@ -235,7 +235,7 @@ func (d *deref) dereferenceStatusable(ctx context.Context, username string, remo
// This function will deference all of the above, insert them in the database as necessary,
// and attach them to the status. The status itself will not be added to the database yet,
// that's up the caller to do.
-func (d *deref) populateStatusFields(ctx context.Context, status *gtsmodel.Status, requestingUsername string) error {
+func (d *deref) populateStatusFields(ctx context.Context, status *gtsmodel.Status, requestingUsername string, includeParent, includeChilds bool) error {
l := d.log.WithFields(logrus.Fields{
"func": "dereferenceStatusFields",
"status": fmt.Sprintf("%+v", status),
@@ -275,14 +275,19 @@ func (d *deref) populateStatusFields(ctx context.Context, status *gtsmodel.Statu
// 3. Emojis
// TODO
- // 4. Mentions
- if err := d.populateStatusMentions(ctx, status, requestingUsername); err != nil {
- return fmt.Errorf("populateStatusFields: error populating status mentions: %s", err)
+ // 4. Mentions (only if requested)
+ // TODO: do we need to handle removing empty mention objects and just using mention IDs slice?
+ if includeChilds {
+ if err := d.populateStatusMentions(ctx, status, requestingUsername); err != nil {
+ return fmt.Errorf("populateStatusFields: error populating status mentions: %s", err)
+ }
}
- // 5. Replied-to-status.
- if err := d.populateStatusRepliedTo(ctx, status, requestingUsername); err != nil {
- return fmt.Errorf("populateStatusFields: error populating status repliedTo: %s", err)
+ // 5. Replied-to-status (only if requested)
+ if includeParent {
+ if err := d.populateStatusRepliedTo(ctx, status, requestingUsername); err != nil {
+ return fmt.Errorf("populateStatusFields: error populating status repliedTo: %s", err)
+ }
}
return nil
@@ -391,7 +396,6 @@ func (d *deref) populateStatusAttachments(ctx context.Context, status *gtsmodel.
attachments := []*gtsmodel.MediaAttachment{}
for _, a := range status.Attachments {
-
aURL, err := url.Parse(a.RemoteURL)
if err != nil {
l.Errorf("populateStatusAttachments: couldn't parse attachment url %s: %s", a.RemoteURL, err)
@@ -401,6 +405,7 @@ func (d *deref) populateStatusAttachments(ctx context.Context, status *gtsmodel.
attachment, err := d.GetRemoteAttachment(ctx, requestingUsername, aURL, status.AccountID, status.ID, a.File.ContentType)
if err != nil {
l.Errorf("populateStatusAttachments: couldn't get remote attachment %s: %s", a.RemoteURL, err)
+ continue
}
attachmentIDs = append(attachmentIDs, attachment.ID)
@@ -420,29 +425,16 @@ func (d *deref) populateStatusRepliedTo(ctx context.Context, status *gtsmodel.St
return err
}
- var replyToStatus *gtsmodel.Status
- errs := []string{}
-
// see if we have the status in our db already
- if s, err := d.db.GetStatusByURI(ctx, status.InReplyToURI); err != nil {
- errs = append(errs, err.Error())
- } else {
- replyToStatus = s
- }
-
- if replyToStatus == nil {
- // didn't find the status in our db, try to get it remotely
- if s, _, _, err := d.GetRemoteStatus(ctx, requestingUsername, statusURI, false); err != nil {
- errs = append(errs, err.Error())
- } else {
- replyToStatus = s
+ replyToStatus, err := d.db.GetStatusByURI(ctx, status.InReplyToURI)
+ if err != nil {
+ // Status was not in the DB, try fetch
+ replyToStatus, _, _, err = d.GetRemoteStatus(ctx, requestingUsername, statusURI, false, false, false)
+ if err != nil {
+ return fmt.Errorf("populateStatusRepliedTo: couldn't get reply to status with uri %s: %s", status.InReplyToURI, err)
}
}
- if replyToStatus == nil {
- return fmt.Errorf("populateStatusRepliedTo: couldn't get reply to status with uri %s: %s", statusURI, strings.Join(errs, " : "))
- }
-
// we have the status
status.InReplyToID = replyToStatus.ID
status.InReplyTo = replyToStatus
diff --git a/internal/federation/dereferencing/status_test.go b/internal/federation/dereferencing/status_test.go
index 2d259682b..43732ac77 100644
--- a/internal/federation/dereferencing/status_test.go
+++ b/internal/federation/dereferencing/status_test.go
@@ -119,7 +119,7 @@ func (suite *StatusTestSuite) TestDereferenceSimpleStatus() {
fetchingAccount := suite.testAccounts["local_account_1"]
statusURL := testrig.URLMustParse("https://unknown-instance.com/users/brand_new_person/statuses/01FE4NTHKWW7THT67EF10EB839")
- status, statusable, new, err := suite.dereferencer.GetRemoteStatus(context.Background(), fetchingAccount.Username, statusURL, false)
+ status, statusable, new, err := suite.dereferencer.GetRemoteStatus(context.Background(), fetchingAccount.Username, statusURL, false, false, false)
suite.NoError(err)
suite.NotNil(status)
suite.NotNil(statusable)
@@ -157,7 +157,7 @@ func (suite *StatusTestSuite) TestDereferenceStatusWithMention() {
fetchingAccount := suite.testAccounts["local_account_1"]
statusURL := testrig.URLMustParse("https://unknown-instance.com/users/brand_new_person/statuses/01FE5Y30E3W4P7TRE0R98KAYQV")
- status, statusable, new, err := suite.dereferencer.GetRemoteStatus(context.Background(), fetchingAccount.Username, statusURL, false)
+ status, statusable, new, err := suite.dereferencer.GetRemoteStatus(context.Background(), fetchingAccount.Username, statusURL, false, false, true)
suite.NoError(err)
suite.NotNil(status)
suite.NotNil(statusable)
diff --git a/internal/federation/dereferencing/thread.go b/internal/federation/dereferencing/thread.go
index f9dd9aa09..af16c01b2 100644
--- a/internal/federation/dereferencing/thread.go
+++ b/internal/federation/dereferencing/thread.go
@@ -49,7 +49,7 @@ func (d *deref) DereferenceThread(ctx context.Context, username string, statusIR
}
// first make sure we have this status in our db
- _, statusable, _, err := d.GetRemoteStatus(ctx, username, statusIRI, true)
+ _, statusable, _, err := d.GetRemoteStatus(ctx, username, statusIRI, true, false, false)
if err != nil {
return fmt.Errorf("DereferenceThread: error getting status with id %s: %s", statusIRI.String(), err)
}
@@ -104,7 +104,7 @@ func (d *deref) iterateAncestors(ctx context.Context, username string, statusIRI
// If we reach here, we're looking at a remote status -- make sure we have it in our db by calling GetRemoteStatus
// We call it with refresh to true because we want the statusable representation to parse inReplyTo from.
- status, statusable, _, err := d.GetRemoteStatus(ctx, username, &statusIRI, true)
+ _, statusable, _, err := d.GetRemoteStatus(ctx, username, &statusIRI, true, false, false)
if err != nil {
l.Debugf("error getting remote status: %s", err)
return nil
@@ -116,18 +116,6 @@ func (d *deref) iterateAncestors(ctx context.Context, username string, statusIRI
return nil
}
- // get the ancestor status into our database if we don't have it yet
- if _, _, _, err := d.GetRemoteStatus(ctx, username, inReplyTo, false); err != nil {
- l.Debugf("error getting remote status: %s", err)
- return nil
- }
-
- // now enrich the current status, since we should have the ancestor in the db
- if _, err := d.EnrichRemoteStatus(ctx, username, status); err != nil {
- l.Debugf("error enriching remote status: %s", err)
- return nil
- }
-
// now move up to the next ancestor
return d.iterateAncestors(ctx, username, *inReplyTo)
}
@@ -226,7 +214,7 @@ pageLoop:
foundReplies = foundReplies + 1
// get the remote statusable and put it in the db
- _, statusable, new, err := d.GetRemoteStatus(ctx, username, itemURI, false)
+ _, statusable, new, err := d.GetRemoteStatus(ctx, username, itemURI, false, false, false)
if new && err == nil && statusable != nil {
// now iterate descendants of *that* status
if err := d.iterateDescendants(ctx, username, *itemURI, statusable); err != nil {
diff --git a/internal/federation/federator.go b/internal/federation/federator.go
index 5eddcbb99..aecddf017 100644
--- a/internal/federation/federator.go
+++ b/internal/federation/federator.go
@@ -62,8 +62,8 @@ type Federator interface {
GetRemoteAccount(ctx context.Context, username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error)
EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error)
- GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error)
- EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error)
+ GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh, includeParent, includeChilds bool) (*gtsmodel.Status, ap.Statusable, bool, error)
+ EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status, includeParent, includeChilds bool) (*gtsmodel.Status, error)
GetRemoteInstance(ctx context.Context, username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error)
@@ -88,7 +88,6 @@ type federator struct {
// NewFederator returns a new federator
func NewFederator(db db.DB, federatingDB federatingdb.DB, transportController transport.Controller, config *config.Config, log *logrus.Logger, typeConverter typeutils.TypeConverter, mediaHandler media.Handler) Federator {
-
dereferencer := dereferencing.NewDereferencer(config, db, typeConverter, transportController, mediaHandler, log)
clock := &Clock{}
diff --git a/internal/federation/sqlite-test.db b/internal/federation/sqlite-test.db
deleted file mode 100644
index d34adbfe9..000000000
--- a/internal/federation/sqlite-test.db
+++ /dev/null
Binary files differ
diff --git a/internal/oauth/sqlite-test.db b/internal/oauth/sqlite-test.db
deleted file mode 100644
index 429e3d860..000000000
--- a/internal/oauth/sqlite-test.db
+++ /dev/null
Binary files differ
diff --git a/internal/processing/fromfederator.go b/internal/processing/fromfederator.go
index 2bb74db34..cb0999cf9 100644
--- a/internal/processing/fromfederator.go
+++ b/internal/processing/fromfederator.go
@@ -49,7 +49,7 @@ func (p *processor) processFromFederator(ctx context.Context, federatorMsg gtsmo
return errors.New("note was not parseable as *gtsmodel.Status")
}
- status, err := p.federator.EnrichRemoteStatus(ctx, federatorMsg.ReceivingAccount.Username, incomingStatus)
+ status, err := p.federator.EnrichRemoteStatus(ctx, federatorMsg.ReceivingAccount.Username, incomingStatus, false, false)
if err != nil {
return err
}
diff --git a/internal/processing/search.go b/internal/processing/search.go
index 768fceacd..85da0d83f 100644
--- a/internal/processing/search.go
+++ b/internal/processing/search.go
@@ -130,7 +130,7 @@ func (p *processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, u
// we don't have it locally so dereference it if we're allowed to
if resolve {
- status, _, _, err := p.federator.GetRemoteStatus(ctx, authed.Account.Username, uri, true)
+ status, _, _, err := p.federator.GetRemoteStatus(ctx, authed.Account.Username, uri, true, false, false)
if err == nil {
if err := p.federator.DereferenceRemoteThread(ctx, authed.Account.Username, uri); err != nil {
// try to deref the thread while we're here
diff --git a/internal/processing/status/sqlite-test.db b/internal/processing/status/sqlite-test.db
deleted file mode 100644
index d266d6b1d..000000000
--- a/internal/processing/status/sqlite-test.db
+++ /dev/null
Binary files differ
diff --git a/internal/text/sqlite-test.db b/internal/text/sqlite-test.db
deleted file mode 100644
index 08b0a8909..000000000
--- a/internal/text/sqlite-test.db
+++ /dev/null
Binary files differ
diff --git a/internal/timeline/sqlite-test.db b/internal/timeline/sqlite-test.db
deleted file mode 100644
index 224027d43..000000000
--- a/internal/timeline/sqlite-test.db
+++ /dev/null
Binary files differ
diff --git a/internal/typeutils/astointernal.go b/internal/typeutils/astointernal.go
index 04d9cd824..4ba0df383 100644
--- a/internal/typeutils/astointernal.go
+++ b/internal/typeutils/astointernal.go
@@ -339,7 +339,6 @@ func (c *converter) ASStatusToStatus(ctx context.Context, statusable ap.Statusab
}
func (c *converter) ASFollowToFollowRequest(ctx context.Context, followable ap.Followable) (*gtsmodel.FollowRequest, error) {
-
idProp := followable.GetJSONLDId()
if idProp == nil || !idProp.IsIRI() {
return nil, errors.New("no id property set on follow, or was not an iri")
diff --git a/internal/typeutils/sqlite-test.db b/internal/typeutils/sqlite-test.db
deleted file mode 100644
index 2775172f1..000000000
--- a/internal/typeutils/sqlite-test.db
+++ /dev/null
Binary files differ