diff options
36 files changed, 653 insertions, 227 deletions
| diff --git a/internal/api/client/account/sqlite-test.db b/internal/api/client/account/sqlite-test.dbBinary files differ deleted file mode 100644 index eab8315d9..000000000 --- a/internal/api/client/account/sqlite-test.db +++ /dev/null diff --git a/internal/api/client/fileserver/sqlite-test.db b/internal/api/client/fileserver/sqlite-test.dbBinary files differ deleted file mode 100644 index 5689e7edb..000000000 --- a/internal/api/client/fileserver/sqlite-test.db +++ /dev/null diff --git a/internal/api/client/media/sqlite-test.db b/internal/api/client/media/sqlite-test.dbBinary files differ deleted file mode 100644 index 1ed985248..000000000 --- a/internal/api/client/media/sqlite-test.db +++ /dev/null diff --git a/internal/api/client/status/sqlite-test.db b/internal/api/client/status/sqlite-test.dbBinary files differ deleted file mode 100644 index 448d10813..000000000 --- a/internal/api/client/status/sqlite-test.db +++ /dev/null diff --git a/internal/api/s2s/user/sqlite-test.db b/internal/api/s2s/user/sqlite-test.dbBinary files differ deleted file mode 100644 index b67967b30..000000000 --- a/internal/api/s2s/user/sqlite-test.db +++ /dev/null diff --git a/internal/cache/account.go b/internal/cache/account.go new file mode 100644 index 000000000..bb402d60f --- /dev/null +++ b/internal/cache/account.go @@ -0,0 +1,157 @@ +package cache + +import ( +	"sync" + +	"github.com/ReneKroon/ttlcache" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// AccountCache is a wrapper around ttlcache.Cache to provide URL and URI lookups for gtsmodel.Account +type AccountCache struct { +	cache *ttlcache.Cache   // map of IDs -> cached accounts +	urls  map[string]string // map of account URLs -> IDs +	uris  map[string]string // map of account URIs -> IDs +	mutex sync.Mutex +} + +// NewAccountCache returns a new instantiated AccountCache object +func NewAccountCache() *AccountCache { +	c := AccountCache{ +		cache: ttlcache.NewCache(), +		urls:  make(map[string]string, 100), +		uris:  make(map[string]string, 100), +		mutex: sync.Mutex{}, +	} + +	// Set callback to purge lookup maps on expiration +	c.cache.SetExpirationCallback(func(key string, value interface{}) { +		account := value.(*gtsmodel.Account) + +		c.mutex.Lock() +		delete(c.urls, account.URL) +		delete(c.uris, account.URI) +		c.mutex.Unlock() +	}) + +	return &c +} + +// GetByID attempts to fetch a account from the cache by its ID, you will receive a copy for thread-safety +func (c *AccountCache) GetByID(id string) (*gtsmodel.Account, bool) { +	c.mutex.Lock() +	account, ok := c.getByID(id) +	c.mutex.Unlock() +	return account, ok +} + +// GetByURL attempts to fetch a account from the cache by its URL, you will receive a copy for thread-safety +func (c *AccountCache) GetByURL(url string) (*gtsmodel.Account, bool) { +	// Perform safe ID lookup +	c.mutex.Lock() +	id, ok := c.urls[url] + +	// Not found, unlock early +	if !ok { +		c.mutex.Unlock() +		return nil, false +	} + +	// Attempt account lookup +	account, ok := c.getByID(id) +	c.mutex.Unlock() +	return account, ok +} + +// GetByURI attempts to fetch a account from the cache by its URI, you will receive a copy for thread-safety +func (c *AccountCache) GetByURI(uri string) (*gtsmodel.Account, bool) { +	// Perform safe ID lookup +	c.mutex.Lock() +	id, ok := c.uris[uri] + +	// Not found, unlock early +	if !ok { +		c.mutex.Unlock() +		return nil, false +	} + +	// Attempt account lookup +	account, ok := c.getByID(id) +	c.mutex.Unlock() +	return account, ok +} + +// getByID performs an unsafe (no mutex locks) lookup of account by ID, returning a copy of account in cache +func (c *AccountCache) getByID(id string) (*gtsmodel.Account, bool) { +	v, ok := c.cache.Get(id) +	if !ok { +		return nil, false +	} +	return copyAccount(v.(*gtsmodel.Account)), true +} + +// Put places a account in the cache, ensuring that the object place is a copy for thread-safety +func (c *AccountCache) Put(account *gtsmodel.Account) { +	if account == nil || account.ID == "" { +		panic("invalid account") +	} + +	c.mutex.Lock() +	c.cache.Set(account.ID, copyAccount(account)) +	if account.URL != "" { +		c.urls[account.URL] = account.ID +	} +	if account.URI != "" { +		c.uris[account.URI] = account.ID +	} +	c.mutex.Unlock() +} + +// copyAccount performs a surface-level copy of account, only keeping attached IDs intact, not the objects. +// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr) +// this should be a relatively cheap process +func copyAccount(account *gtsmodel.Account) *gtsmodel.Account { +	return >smodel.Account{ +		ID:                      account.ID, +		Username:                account.Username, +		Domain:                  account.Domain, +		AvatarMediaAttachmentID: account.AvatarMediaAttachmentID, +		AvatarMediaAttachment:   nil, +		AvatarRemoteURL:         account.AvatarRemoteURL, +		HeaderMediaAttachmentID: account.HeaderMediaAttachmentID, +		HeaderMediaAttachment:   nil, +		HeaderRemoteURL:         account.HeaderRemoteURL, +		DisplayName:             account.DisplayName, +		Fields:                  account.Fields, +		Note:                    account.Note, +		Memorial:                account.Memorial, +		MovedToAccountID:        account.MovedToAccountID, +		CreatedAt:               account.CreatedAt, +		UpdatedAt:               account.UpdatedAt, +		Bot:                     account.Bot, +		Reason:                  account.Reason, +		Locked:                  account.Locked, +		Discoverable:            account.Discoverable, +		Privacy:                 account.Privacy, +		Sensitive:               account.Sensitive, +		Language:                account.Language, +		URI:                     account.URI, +		URL:                     account.URL, +		LastWebfingeredAt:       account.LastWebfingeredAt, +		InboxURI:                account.InboxURI, +		OutboxURI:               account.OutboxURI, +		FollowingURI:            account.FollowingURI, +		FollowersURI:            account.FollowersURI, +		FeaturedCollectionURI:   account.FeaturedCollectionURI, +		ActorType:               account.ActorType, +		AlsoKnownAs:             account.AlsoKnownAs, +		PrivateKey:              account.PrivateKey, +		PublicKey:               account.PublicKey, +		PublicKeyURI:            account.PublicKeyURI, +		SensitizedAt:            account.SensitizedAt, +		SilencedAt:              account.SilencedAt, +		SuspendedAt:             account.SuspendedAt, +		HideCollections:         account.HideCollections, +		SuspensionOrigin:        account.SuspensionOrigin, +	} +} diff --git a/internal/cache/account_test.go b/internal/cache/account_test.go new file mode 100644 index 000000000..f84ad2261 --- /dev/null +++ b/internal/cache/account_test.go @@ -0,0 +1,63 @@ +package cache_test + +import ( +	"testing" + +	"github.com/stretchr/testify/suite" +	"github.com/superseriousbusiness/gotosocial/internal/cache" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/testrig" +) + +type AccountCacheTestSuite struct { +	suite.Suite +	data  map[string]*gtsmodel.Account +	cache *cache.AccountCache +} + +func (suite *AccountCacheTestSuite) SetupSuite() { +	suite.data = testrig.NewTestAccounts() +} + +func (suite *AccountCacheTestSuite) SetupTest() { +	suite.cache = cache.NewAccountCache() +} + +func (suite *AccountCacheTestSuite) TearDownTest() { +	suite.data = nil +	suite.cache = nil +} + +func (suite *AccountCacheTestSuite) TestAccountCache() { +	for _, account := range suite.data { +		// Place in the cache +		suite.cache.Put(account) +	} + +	for _, account := range suite.data { +		var ok bool +		var check *gtsmodel.Account + +		// Check we can retrieve +		check, ok = suite.cache.GetByID(account.ID) +		if !ok && !accountIs(account, check) { +			suite.Fail("Failed to fetch expected account with ID: %s", account.ID) +		} +		check, ok = suite.cache.GetByURI(account.URI) +		if account.URI != "" && !ok && !accountIs(account, check) { +			suite.Fail("Failed to fetch expected account with URI: %s", account.URI) +		} +		check, ok = suite.cache.GetByURL(account.URL) +		if account.URL != "" && !ok && !accountIs(account, check) { +			suite.Fail("Failed to fetch expected account with URL: %s", account.URL) +		} +	} +} + +func TestAccountCache(t *testing.T) { +	suite.Run(t, &AccountCacheTestSuite{}) +} + +func accountIs(account1, account2 *gtsmodel.Account) bool { +	return account1.ID == account2.ID && account1.URI == account2.URI && account1.URL == account2.URL +} diff --git a/internal/cache/status.go b/internal/cache/status.go index 895a5692c..028abc8f7 100644 --- a/internal/cache/status.go +++ b/internal/cache/status.go @@ -37,7 +37,7 @@ func NewStatusCache() *StatusCache {  	return &c  } -// GetByID attempts to fetch a status from the cache by its ID +// GetByID attempts to fetch a status from the cache by its ID, you will receive a copy for thread-safety  func (c *StatusCache) GetByID(id string) (*gtsmodel.Status, bool) {  	c.mutex.Lock()  	status, ok := c.getByID(id) @@ -45,7 +45,7 @@ func (c *StatusCache) GetByID(id string) (*gtsmodel.Status, bool) {  	return status, ok  } -// GetByURL attempts to fetch a status from the cache by its URL +// GetByURL attempts to fetch a status from the cache by its URL, you will receive a copy for thread-safety  func (c *StatusCache) GetByURL(url string) (*gtsmodel.Status, bool) {  	// Perform safe ID lookup  	c.mutex.Lock() @@ -63,7 +63,7 @@ func (c *StatusCache) GetByURL(url string) (*gtsmodel.Status, bool) {  	return status, ok  } -// GetByURI attempts to fetch a status from the cache by its URI +// GetByURI attempts to fetch a status from the cache by its URI, you will receive a copy for thread-safety  func (c *StatusCache) GetByURI(uri string) (*gtsmodel.Status, bool) {  	// Perform safe ID lookup  	c.mutex.Lock() @@ -81,26 +81,72 @@ func (c *StatusCache) GetByURI(uri string) (*gtsmodel.Status, bool) {  	return status, ok  } -// getByID performs an unsafe (no mutex locks) lookup of status by ID +// getByID performs an unsafe (no mutex locks) lookup of status by ID, returning a copy of status in cache  func (c *StatusCache) getByID(id string) (*gtsmodel.Status, bool) {  	v, ok := c.cache.Get(id)  	if !ok {  		return nil, false  	} -	return v.(*gtsmodel.Status), true +	return copyStatus(v.(*gtsmodel.Status)), true  } -// Put places a status in the cache +// Put places a status in the cache, ensuring that the object place is a copy for thread-safety  func (c *StatusCache) Put(status *gtsmodel.Status) { -	if status == nil || status.ID == "" || -		status.URL == "" || -		status.URI == "" { +	if status == nil || status.ID == "" {  		panic("invalid status")  	}  	c.mutex.Lock() -	c.cache.Set(status.ID, status) -	c.urls[status.URL] = status.ID -	c.uris[status.URI] = status.ID +	c.cache.Set(status.ID, copyStatus(status)) +	if status.URL != "" { +		c.urls[status.URL] = status.ID +	} +	if status.URI != "" { +		c.uris[status.URI] = status.ID +	}  	c.mutex.Unlock()  } + +// copyStatus performs a surface-level copy of status, only keeping attached IDs intact, not the objects. +// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr) +// this should be a relatively cheap process +func copyStatus(status *gtsmodel.Status) *gtsmodel.Status { +	return >smodel.Status{ +		ID:                       status.ID, +		URI:                      status.URI, +		URL:                      status.URL, +		Content:                  status.Content, +		AttachmentIDs:            status.AttachmentIDs, +		Attachments:              nil, +		TagIDs:                   status.TagIDs, +		Tags:                     nil, +		MentionIDs:               status.MentionIDs, +		Mentions:                 nil, +		EmojiIDs:                 status.EmojiIDs, +		Emojis:                   nil, +		CreatedAt:                status.CreatedAt, +		UpdatedAt:                status.UpdatedAt, +		Local:                    status.Local, +		AccountID:                status.AccountID, +		Account:                  nil, +		AccountURI:               status.AccountURI, +		InReplyToID:              status.InReplyToID, +		InReplyTo:                nil, +		InReplyToURI:             status.InReplyToURI, +		InReplyToAccountID:       status.InReplyToAccountID, +		InReplyToAccount:         nil, +		BoostOfID:                status.BoostOfID, +		BoostOf:                  nil, +		BoostOfAccountID:         status.BoostOfAccountID, +		BoostOfAccount:           nil, +		ContentWarning:           status.ContentWarning, +		Visibility:               status.Visibility, +		Sensitive:                status.Sensitive, +		Language:                 status.Language, +		CreatedWithApplicationID: status.CreatedWithApplicationID, +		VisibilityAdvanced:       status.VisibilityAdvanced, +		ActivityStreamsType:      status.ActivityStreamsType, +		Text:                     status.Text, +		Pinned:                   status.Pinned, +	} +} diff --git a/internal/cache/status_test.go b/internal/cache/status_test.go index 10dee5bca..222961025 100644 --- a/internal/cache/status_test.go +++ b/internal/cache/status_test.go @@ -3,39 +3,61 @@ package cache_test  import (  	"testing" +	"github.com/stretchr/testify/suite"  	"github.com/superseriousbusiness/gotosocial/internal/cache"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/testrig"  ) -func TestStatusCache(t *testing.T) { -	cache := cache.NewStatusCache() +type StatusCacheTestSuite struct { +	suite.Suite +	data  map[string]*gtsmodel.Status +	cache *cache.StatusCache +} -	// Attempt to place a status -	status := gtsmodel.Status{ -		ID:  "id", -		URI: "uri", -		URL: "url", -	} -	cache.Put(&status) +func (suite *StatusCacheTestSuite) SetupSuite() { +	suite.data = testrig.NewTestStatuses() +} -	var ok bool -	var check *gtsmodel.Status +func (suite *StatusCacheTestSuite) SetupTest() { +	suite.cache = cache.NewStatusCache() +} -	// Check we can retrieve -	check, ok = cache.GetByID(status.ID) -	if !ok || !statusIs(&status, check) { -		t.Fatal("Could not find expected status") -	} -	check, ok = cache.GetByURI(status.URI) -	if !ok || !statusIs(&status, check) { -		t.Fatal("Could not find expected status") +func (suite *StatusCacheTestSuite) TearDownTest() { +	suite.data = nil +	suite.cache = nil +} + +func (suite *StatusCacheTestSuite) TestStatusCache() { +	for _, status := range suite.data { +		// Place in the cache +		suite.cache.Put(status)  	} -	check, ok = cache.GetByURL(status.URL) -	if !ok || !statusIs(&status, check) { -		t.Fatal("Could not find expected status") + +	for _, status := range suite.data { +		var ok bool +		var check *gtsmodel.Status + +		// Check we can retrieve +		check, ok = suite.cache.GetByID(status.ID) +		if !ok && !statusIs(status, check) { +			suite.Fail("Failed to fetch expected account with ID: %s", status.ID) +		} +		check, ok = suite.cache.GetByURI(status.URI) +		if status.URI != "" && !ok && !statusIs(status, check) { +			suite.Fail("Failed to fetch expected account with URI: %s", status.URI) +		} +		check, ok = suite.cache.GetByURL(status.URL) +		if status.URL != "" && !ok && !statusIs(status, check) { +			suite.Fail("Failed to fetch expected account with URL: %s", status.URL) +		}  	}  } +func TestStatusCache(t *testing.T) { +	suite.Run(t, &StatusCacheTestSuite{}) +} +  func statusIs(status1, status2 *gtsmodel.Status) bool {  	return status1.ID == status2.ID && status1.URI == status2.URI && status1.URL == status2.URL  } diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index d7d45a739..32a70f7cd 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -25,6 +25,7 @@ import (  	"strings"  	"time" +	"github.com/superseriousbusiness/gotosocial/internal/cache"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -34,6 +35,7 @@ import (  type accountDB struct {  	config *config.Config  	conn   *DBConn +	cache  *cache.AccountCache  }  func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { @@ -45,60 +47,80 @@ func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {  }  func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { -	account := new(gtsmodel.Account) - -	q := a.newAccountQ(account). -		Where("account.id = ?", id) - -	err := q.Scan(ctx) -	if err != nil { -		return nil, a.conn.ProcessError(err) -	} -	return account, nil +	return a.getAccount( +		ctx, +		func() (*gtsmodel.Account, bool) { +			return a.cache.GetByID(id) +		}, +		func(account *gtsmodel.Account) error { +			return a.newAccountQ(account).Where("account.id = ?", id).Scan(ctx) +		}, +	)  }  func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { -	account := new(gtsmodel.Account) - -	q := a.newAccountQ(account). -		Where("account.uri = ?", uri) +	return a.getAccount( +		ctx, +		func() (*gtsmodel.Account, bool) { +			return a.cache.GetByURI(uri) +		}, +		func(account *gtsmodel.Account) error { +			return a.newAccountQ(account).Where("account.uri = ?", uri).Scan(ctx) +		}, +	) +} -	err := q.Scan(ctx) -	if err != nil { -		return nil, a.conn.ProcessError(err) -	} -	return account, nil +func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, db.Error) { +	return a.getAccount( +		ctx, +		func() (*gtsmodel.Account, bool) { +			return a.cache.GetByURL(url) +		}, +		func(account *gtsmodel.Account) error { +			return a.newAccountQ(account).Where("account.url = ?", url).Scan(ctx) +		}, +	)  } -func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { -	account := new(gtsmodel.Account) +func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) { +	// Attempt to fetch cached account +	account, cached := cacheGet() -	q := a.newAccountQ(account). -		Where("account.url = ?", uri) +	if !cached { +		account = >smodel.Account{} -	err := q.Scan(ctx) -	if err != nil { -		return nil, a.conn.ProcessError(err) +		// Not cached! Perform database query +		err := dbQuery(account) +		if err != nil { +			return nil, a.conn.ProcessError(err) +		} + +		// Place in the cache +		a.cache.Put(account)  	} +  	return account, nil  }  func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {  	if strings.TrimSpace(account.ID) == "" { +		// TODO: we should not need this check here  		return nil, errors.New("account had no ID")  	} +	// Update the account's last-used  	account.UpdatedAt = time.Now() -	q := a.conn. -		NewUpdate(). -		Model(account). -		WherePK() - -	_, err := q.Exec(ctx) +	// Update the account model in the DB +	_, err := a.conn.NewUpdate().Model(account).WherePK().Exec(ctx)  	if err != nil {  		return nil, a.conn.ProcessError(err)  	} + +	// Place updated account in cache +	// (this will replace existing, i.e. invalidating) +	a.cache.Put(account) +  	return account, nil  } diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 248232fe3..6fcc56e51 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -91,6 +91,15 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)  		conn = WrapDBConn(bun.NewDB(sqldb, pgdialect.New()), log)  	case dbTypeSqlite:  		// SQLITE + +		// Drop anything fancy from DB address +		c.DBConfig.Address = strings.Split(c.DBConfig.Address, "?")[0] +		c.DBConfig.Address = strings.TrimPrefix(c.DBConfig.Address, "file:") + +		// Append our own SQLite preferences +		c.DBConfig.Address = "file:" + c.DBConfig.Address + "?cache=shared" + +		// Open new DB instance  		var err error  		sqldb, err = sql.Open("sqlite", c.DBConfig.Address)  		if err != nil { @@ -98,7 +107,7 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)  		}  		conn = WrapDBConn(bun.NewDB(sqldb, sqlitedialect.New()), log) -		if strings.HasPrefix(strings.TrimPrefix(c.DBConfig.Address, "file:"), ":memory:") { +		if c.DBConfig.Address == "file::memory:?cache=shared" {  			log.Warn("sqlite in-memory database should only be used for debugging")  			// don't close connections on disconnect -- otherwise @@ -121,11 +130,10 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)  		conn.RegisterModel(t)  	} +	accounts := &accountDB{config: c, conn: conn, cache: cache.NewAccountCache()} +  	ps := &bunDBService{ -		Account: &accountDB{ -			config: c, -			conn:   conn, -		}, +		Account: accounts,  		Admin: &adminDB{  			config: c,  			conn:   conn, @@ -165,9 +173,10 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)  			conn:   conn,  		},  		Status: &statusDB{ -			config: c, -			conn:   conn, -			cache:  cache.NewStatusCache(), +			config:   c, +			conn:     conn, +			cache:    cache.NewStatusCache(), +			accounts: accounts,  		},  		Timeline: &timelineDB{  			config: c, diff --git a/internal/db/bundb/conn.go b/internal/db/bundb/conn.go index 698adff3d..abaebcebd 100644 --- a/internal/db/bundb/conn.go +++ b/internal/db/bundb/conn.go @@ -12,6 +12,8 @@ import (  // dbConn wrapps a bun.DB conn to provide SQL-type specific additional functionality  type DBConn struct { +	// TODO: move *Config here, no need to be in each struct type +  	errProc func(error) db.Error // errProc is the SQL-type specific error processor  	log     *logrus.Logger       // log is the logger passed with this DBConn  	*bun.DB                      // DB is the underlying bun.DB connection @@ -35,6 +37,24 @@ func WrapDBConn(dbConn *bun.DB, log *logrus.Logger) *DBConn {  	}  } +func (conn *DBConn) RunInTx(ctx context.Context, fn func(bun.Tx) error) db.Error { +	// Acquire a new transaction +	tx, err := conn.BeginTx(ctx, nil) +	if err != nil { +		return conn.ProcessError(err) +	} + +	// Perform supplied transaction +	if err = fn(tx); err != nil { +		tx.Rollback() //nolint +		return conn.ProcessError(err) +	} + +	// Finally, commit transaction +	err = tx.Commit() +	return conn.ProcessError(err) +} +  // ProcessError processes an error to replace any known values with our own db.Error types,  // making it easier to catch specific situations (e.g. no rows, already exists, etc)  func (conn *DBConn) ProcessError(err error) db.Error { diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index 56b752593..64d896527 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -237,7 +237,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI  	if _, err := r.conn.  		NewInsert().  		Model(follow). -		On("CONFLICT ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI). +		On("CONFLICT (account_id,target_account_id) DO UPDATE set uri = ?", follow.URI).  		Exec(ctx); err != nil {  		return nil, r.conn.ProcessError(err)  	} @@ -298,7 +298,7 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str  	if localOnly {  		q = q.ColumnExpr("follow.*"). -			Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)"). +			Join("JOIN accounts AS a ON follow.account_id = CAST(a.id as TEXT)").  			Where("follow.target_account_id = ?", accountID).  			WhereGroup(" AND ", whereEmptyOrNull("a.domain"))  	} else { diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go new file mode 100644 index 000000000..dcc71b37c --- /dev/null +++ b/internal/db/bundb/relationship_test.go @@ -0,0 +1,124 @@ +/* +   GoToSocial +   Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + +   This program is free software: you can redistribute it and/or modify +   it under the terms of the GNU Affero General Public License as published by +   the Free Software Foundation, either version 3 of the License, or +   (at your option) any later version. + +   This program is distributed in the hope that it will be useful, +   but WITHOUT ANY WARRANTY; without even the implied warranty of +   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +   GNU Affero General Public License for more details. + +   You should have received a copy of the GNU Affero General Public License +   along with this program.  If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb_test + +import ( +	"context" +	"errors" +	"testing" + +	"github.com/stretchr/testify/suite" +	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/testrig" +) + +type RelationshipTestSuite struct { +	BunDBStandardTestSuite +} + +func (suite *RelationshipTestSuite) SetupSuite() { +	suite.testTokens = testrig.NewTestTokens() +	suite.testClients = testrig.NewTestClients() +	suite.testApplications = testrig.NewTestApplications() +	suite.testUsers = testrig.NewTestUsers() +	suite.testAccounts = testrig.NewTestAccounts() +	suite.testAttachments = testrig.NewTestAttachments() +	suite.testStatuses = testrig.NewTestStatuses() +	suite.testTags = testrig.NewTestTags() +	suite.testMentions = testrig.NewTestMentions() +} + +func (suite *RelationshipTestSuite) SetupTest() { +	suite.config = testrig.NewTestConfig() +	suite.db = testrig.NewTestDB() +	suite.log = testrig.NewTestLog() + +	testrig.StandardDBSetup(suite.db, suite.testAccounts) +} + +func (suite *RelationshipTestSuite) TearDownTest() { +	testrig.StandardDBTeardown(suite.db) +} + +func (suite *RelationshipTestSuite) TestIsBlocked() { +	suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) TestGetBlock() { +	suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) TestGetRelationship() { +	suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) TestIsFollowing() { +	suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) TestIsMutualFollowing() { +	suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) AcceptFollowRequest() { +	for _, account := range suite.testAccounts { +		_, err := suite.db.AcceptFollowRequest(context.Background(), account.ID, "NON-EXISTENT-ID") +		if err != nil && !errors.Is(err, db.ErrNoEntries) { +			suite.Suite.Fail("error accepting follow request: %v", err) +		} +	} +} + +func (suite *RelationshipTestSuite) GetAccountFollowRequests() { +	suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) GetAccountFollows() { +	suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) CountAccountFollows() { +	suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) GetAccountFollowedBy() { +	// TODO: more comprehensive tests here + +	for _, account := range suite.testAccounts { +		var err error + +		_, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, false) +		if err != nil { +			suite.Suite.Fail("error checking accounts followed by: %v", err) +		} + +		_, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, true) +		if err != nil { +			suite.Suite.Fail("error checking localOnly accounts followed by: %v", err) +		} +	} +} + +func (suite *RelationshipTestSuite) CountAccountFollowedBy() { +	suite.Suite.T().Skip("TODO: implement") +} + +func TestRelationshipTestSuite(t *testing.T) { +	suite.Run(t, new(RelationshipTestSuite)) +} diff --git a/internal/db/bundb/sqlite-test.db b/internal/db/bundb/sqlite-test.dbBinary files differ deleted file mode 100644 index ed3b25ee3..000000000 --- a/internal/db/bundb/sqlite-test.db +++ /dev/null diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index 1d5acf0fc..9464cfadf 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -21,7 +21,6 @@ package bundb  import (  	"container/list"  	"context" -	"errors"  	"time"  	"github.com/superseriousbusiness/gotosocial/internal/cache" @@ -35,6 +34,11 @@ type statusDB struct {  	config *config.Config  	conn   *DBConn  	cache  *cache.StatusCache + +	// TODO: keep method definitions in same place but instead have receiver +	//       all point to one single "db" type, so they can all share methods +	//       and caches where necessary +	accounts *accountDB  }  func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { @@ -51,30 +55,6 @@ func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {  		Relation("CreatedWithApplication")  } -func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status { -	if status.InReplyToID != "" && status.InReplyTo == nil { -		// TODO: do we want to keep this possibly recursive strategy? - -		if inReplyTo, cached := s.cache.GetByID(status.InReplyToID); cached { -			status.InReplyTo = inReplyTo -		} else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil { -			status.InReplyTo = inReplyTo -		} -	} - -	if status.BoostOfID != "" && status.BoostOf == nil { -		// TODO: do we want to keep this possibly recursive strategy? - -		if boostOf, cached := s.cache.GetByID(status.BoostOfID); cached { -			status.BoostOf = boostOf -		} else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil { -			status.BoostOf = boostOf -		} -	} - -	return status -} -  func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {  	return s.conn.  		NewSelect(). @@ -85,64 +65,79 @@ func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {  }  func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) { -	if status, cached := s.cache.GetByID(id); cached { -		return status, nil -	} - -	status := >smodel.Status{} - -	q := s.newStatusQ(status). -		Where("status.id = ?", id) - -	err := q.Scan(ctx) -	if err != nil { -		return nil, s.conn.ProcessError(err) -	} - -	s.cache.Put(status) -	return s.getAttachedStatuses(ctx, status), nil +	return s.getStatus( +		ctx, +		func() (*gtsmodel.Status, bool) { +			return s.cache.GetByID(id) +		}, +		func(status *gtsmodel.Status) error { +			return s.newStatusQ(status).Where("status.id = ?", id).Scan(ctx) +		}, +	)  }  func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) { -	if status, cached := s.cache.GetByURI(uri); cached { -		return status, nil -	} - -	status := >smodel.Status{} +	return s.getStatus( +		ctx, +		func() (*gtsmodel.Status, bool) { +			return s.cache.GetByURI(uri) +		}, +		func(status *gtsmodel.Status) error { +			return s.newStatusQ(status).Where("LOWER(status.uri) = LOWER(?)", uri).Scan(ctx) +		}, +	) +} -	q := s.newStatusQ(status). -		Where("LOWER(status.uri) = LOWER(?)", uri) +func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) { +	return s.getStatus( +		ctx, +		func() (*gtsmodel.Status, bool) { +			return s.cache.GetByURL(url) +		}, +		func(status *gtsmodel.Status) error { +			return s.newStatusQ(status).Where("LOWER(status.url) = LOWER(?)", url).Scan(ctx) +		}, +	) +} -	err := q.Scan(ctx) -	if err != nil { -		return nil, s.conn.ProcessError(err) -	} +func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Status, bool), dbQuery func(*gtsmodel.Status) error) (*gtsmodel.Status, db.Error) { +	// Attempt to fetch cached status +	status, cached := cacheGet() -	s.cache.Put(status) -	return s.getAttachedStatuses(ctx, status), nil -} +	if !cached { +		status = >smodel.Status{} -func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) { -	if status, cached := s.cache.GetByURL(url); cached { -		return status, nil -	} +		// Not cached! Perform database query +		err := dbQuery(status) +		if err != nil { +			return nil, s.conn.ProcessError(err) +		} -	status := >smodel.Status{} +		// If there is boosted, fetch from DB also +		if status.BoostOfID != "" { +			boostOf, err := s.GetStatusByID(ctx, status.BoostOfID) +			if err == nil { +				status.BoostOf = boostOf +			} +		} -	q := s.newStatusQ(status). -		Where("LOWER(status.url) = LOWER(?)", url) +		// Place in the cache +		s.cache.Put(status) +	} -	err := q.Scan(ctx) +	// Set the status author account +	author, err := s.accounts.GetAccountByID(ctx, status.AccountID)  	if err != nil { -		return nil, s.conn.ProcessError(err) +		return nil, err  	} -	s.cache.Put(status) -	return s.getAttachedStatuses(ctx, status), nil +	// Return the prepared status +	status.Account = author +	return status, nil  }  func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error { -	transaction := func(ctx context.Context, tx bun.Tx) error { +	return s.conn.RunInTx(ctx, func(tx bun.Tx) error {  		// create links between this status and any emojis it uses  		for _, i := range status.EmojiIDs {  			if _, err := tx.NewInsert().Model(>smodel.StatusToEmoji{ @@ -174,10 +169,10 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er  			}  		} +		// Finally, insert the status  		_, err := tx.NewInsert().Model(status).Exec(ctx)  		return err -	} -	return s.conn.ProcessError(s.conn.RunInTx(ctx, nil, transaction)) +	})  }  func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { @@ -210,12 +205,8 @@ func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Statu  	children := []*gtsmodel.Status{}  	for e := foundStatuses.Front(); e != nil; e = e.Next() { -		entry, ok := e.Value.(*gtsmodel.Status) -		if !ok { -			panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) -		} -  		// only append children, not the overall parent status +		entry := e.Value.(*gtsmodel.Status)  		if entry.ID != status.ID {  			children = append(children, entry)  		} @@ -242,11 +233,7 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status,  	for _, child := range immediateChildren {  	insertLoop:  		for e := foundStatuses.Front(); e != nil; e = e.Next() { -			entry, ok := e.Value.(*gtsmodel.Status) -			if !ok { -				panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) -			} - +			entry := e.Value.(*gtsmodel.Status)  			if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {  				foundStatuses.InsertAfter(child, e)  				break insertLoop diff --git a/internal/db/bundb/status_test.go b/internal/db/bundb/status_test.go index 4f846441b..7acc86ff9 100644 --- a/internal/db/bundb/status_test.go +++ b/internal/db/bundb/status_test.go @@ -105,10 +105,9 @@ func (suite *StatusTestSuite) TestGetStatusWithMention() {  	suite.NotNil(status)  	suite.NotNil(status.Account)  	suite.NotNil(status.CreatedWithApplication) -	suite.NotEmpty(status.Mentions)  	suite.NotEmpty(status.MentionIDs) -	suite.NotNil(status.InReplyTo) -	suite.NotNil(status.InReplyToAccount) +	suite.NotEmpty(status.InReplyToID) +	suite.NotEmpty(status.InReplyToAccountID)  }  func (suite *StatusTestSuite) TestGetStatusTwice() { diff --git a/internal/db/status.go b/internal/db/status.go index 7430433c4..f26f8942e 100644 --- a/internal/db/status.go +++ b/internal/db/status.go @@ -26,13 +26,13 @@ import (  // Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses.  type Status interface { -	// GetStatusByID returns one status from the database, with all rel fields populated (if possible). +	// GetStatusByID returns one status from the database, with no rel fields populated, only their linking ID / URIs  	GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, Error) -	// GetStatusByURI returns one status from the database, with all rel fields populated (if possible). +	// GetStatusByURI returns one status from the database, with no rel fields populated, only their linking ID / URIs  	GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, Error) -	// GetStatusByURL returns one status from the database, with all rel fields populated (if possible). +	// GetStatusByURL returns one status from the database, with no rel fields populated, only their linking ID / URIs  	GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error)  	// PutStatus stores one status in the database. diff --git a/internal/federation/dereference.go b/internal/federation/dereference.go index a09f0f84b..a9dbabb42 100644 --- a/internal/federation/dereference.go +++ b/internal/federation/dereference.go @@ -34,12 +34,12 @@ func (f *federator) EnrichRemoteAccount(ctx context.Context, username string, ac  	return f.dereferencer.EnrichRemoteAccount(ctx, username, account)  } -func (f *federator) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) { -	return f.dereferencer.GetRemoteStatus(ctx, username, remoteStatusID, refresh) +func (f *federator) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh, includeParent, includeChilds bool) (*gtsmodel.Status, ap.Statusable, bool, error) { +	return f.dereferencer.GetRemoteStatus(ctx, username, remoteStatusID, refresh, includeParent, includeChilds)  } -func (f *federator) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error) { -	return f.dereferencer.EnrichRemoteStatus(ctx, username, status) +func (f *federator) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status, includeParent, includeChilds bool) (*gtsmodel.Status, error) { +	return f.dereferencer.EnrichRemoteStatus(ctx, username, status, includeParent, includeChilds)  }  func (f *federator) DereferenceRemoteThread(ctx context.Context, username string, statusIRI *url.URL) error { diff --git a/internal/federation/dereferencing/account.go b/internal/federation/dereferencing/account.go index 2eee0645d..8cae002e8 100644 --- a/internal/federation/dereferencing/account.go +++ b/internal/federation/dereferencing/account.go @@ -48,7 +48,6 @@ func instanceAccount(account *gtsmodel.Account) bool {  // EnrichRemoteAccount is mostly useful for calling after an account has been initially created by  // the federatingDB's Create function, or during the federated authorization flow.  func (d *deref) EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error) { -  	// if we're dealing with an instance account, we don't need to update anything  	if instanceAccount(account) {  		return account, nil @@ -58,13 +57,13 @@ func (d *deref) EnrichRemoteAccount(ctx context.Context, username string, accoun  		return nil, err  	} -	var err error -	account, err = d.db.UpdateAccount(ctx, account) +	updated, err := d.db.UpdateAccount(ctx, account)  	if err != nil {  		d.log.Errorf("EnrichRemoteAccount: error updating account: %s", err) +		return account, nil  	} -	return account, nil +	return updated, nil  }  // GetRemoteAccount completely dereferences a remote account, converts it to a GtS model account, diff --git a/internal/federation/dereferencing/announce.go b/internal/federation/dereferencing/announce.go index 33af74ebe..d5cc5ad0c 100644 --- a/internal/federation/dereferencing/announce.go +++ b/internal/federation/dereferencing/announce.go @@ -46,7 +46,7 @@ func (d *deref) DereferenceAnnounce(ctx context.Context, announce *gtsmodel.Stat  		return fmt.Errorf("DereferenceAnnounce: error dereferencing thread of boosted status: %s", err)  	} -	boostedStatus, _, _, err := d.GetRemoteStatus(ctx, requestingUsername, boostedStatusURI, false) +	boostedStatus, _, _, err := d.GetRemoteStatus(ctx, requestingUsername, boostedStatusURI, false, false, false)  	if err != nil {  		return fmt.Errorf("DereferenceAnnounce: error dereferencing remote status with id %s: %s", announce.BoostOf.URI, err)  	} diff --git a/internal/federation/dereferencing/dereferencer.go b/internal/federation/dereferencing/dereferencer.go index 4191bd283..8ad21013f 100644 --- a/internal/federation/dereferencing/dereferencer.go +++ b/internal/federation/dereferencing/dereferencer.go @@ -38,8 +38,8 @@ type Dereferencer interface {  	GetRemoteAccount(ctx context.Context, username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error)  	EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error) -	GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) -	EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error) +	GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh, includeParent, includeChilds bool) (*gtsmodel.Status, ap.Statusable, bool, error) +	EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status, includeParent, includeChilds bool) (*gtsmodel.Status, error)  	GetRemoteInstance(ctx context.Context, username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) diff --git a/internal/federation/dereferencing/sqlite-test.db b/internal/federation/dereferencing/sqlite-test.dbBinary files differ deleted file mode 100644 index bef45b3af..000000000 --- a/internal/federation/dereferencing/sqlite-test.db +++ /dev/null diff --git a/internal/federation/dereferencing/status.go b/internal/federation/dereferencing/status.go index 3fa1e4133..7a7f928f1 100644 --- a/internal/federation/dereferencing/status.go +++ b/internal/federation/dereferencing/status.go @@ -39,8 +39,8 @@ import (  //  // EnrichRemoteStatus is mostly useful for calling after a status has been initially created by  // the federatingDB's Create function, but additional dereferencing is needed on it. -func (d *deref) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error) { -	if err := d.populateStatusFields(ctx, status, username); err != nil { +func (d *deref) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status, includeParent, includeChilds bool) (*gtsmodel.Status, error) { +	if err := d.populateStatusFields(ctx, status, username, includeParent, includeChilds); err != nil {  		return nil, err  	} @@ -62,7 +62,7 @@ func (d *deref) EnrichRemoteStatus(ctx context.Context, username string, status  // If a dereference was performed, then the function also returns the ap.Statusable representation for further processing.  //  // SIDE EFFECTS: remote status will be stored in the database, and the remote status owner will also be stored. -func (d *deref) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) { +func (d *deref) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh, includeParent, includeChilds bool) (*gtsmodel.Status, ap.Statusable, bool, error) {  	new := true  	// check if we already have the status in our db @@ -105,7 +105,7 @@ func (d *deref) GetRemoteStatus(ctx context.Context, username string, remoteStat  		}  		gtsStatus.ID = ulid -		if err := d.populateStatusFields(ctx, gtsStatus, username); err != nil { +		if err := d.populateStatusFields(ctx, gtsStatus, username, includeParent, includeChilds); err != nil {  			return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error populating status fields: %s", err)  		} @@ -115,7 +115,7 @@ func (d *deref) GetRemoteStatus(ctx context.Context, username string, remoteStat  	} else {  		gtsStatus.ID = maybeStatus.ID -		if err := d.populateStatusFields(ctx, gtsStatus, username); err != nil { +		if err := d.populateStatusFields(ctx, gtsStatus, username, includeParent, includeChilds); err != nil {  			return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error populating status fields: %s", err)  		} @@ -235,7 +235,7 @@ func (d *deref) dereferenceStatusable(ctx context.Context, username string, remo  // This function will deference all of the above, insert them in the database as necessary,  // and attach them to the status. The status itself will not be added to the database yet,  // that's up the caller to do. -func (d *deref) populateStatusFields(ctx context.Context, status *gtsmodel.Status, requestingUsername string) error { +func (d *deref) populateStatusFields(ctx context.Context, status *gtsmodel.Status, requestingUsername string, includeParent, includeChilds bool) error {  	l := d.log.WithFields(logrus.Fields{  		"func":   "dereferenceStatusFields",  		"status": fmt.Sprintf("%+v", status), @@ -275,14 +275,19 @@ func (d *deref) populateStatusFields(ctx context.Context, status *gtsmodel.Statu  	// 3. Emojis  	// TODO -	// 4. Mentions -	if err := d.populateStatusMentions(ctx, status, requestingUsername); err != nil { -		return fmt.Errorf("populateStatusFields: error populating status mentions: %s", err) +	// 4. Mentions (only if requested) +	// TODO: do we need to handle removing empty mention objects and just using mention IDs slice? +	if includeChilds { +		if err := d.populateStatusMentions(ctx, status, requestingUsername); err != nil { +			return fmt.Errorf("populateStatusFields: error populating status mentions: %s", err) +		}  	} -	// 5. Replied-to-status. -	if err := d.populateStatusRepliedTo(ctx, status, requestingUsername); err != nil { -		return fmt.Errorf("populateStatusFields: error populating status repliedTo: %s", err) +	// 5. Replied-to-status (only if requested) +	if includeParent { +		if err := d.populateStatusRepliedTo(ctx, status, requestingUsername); err != nil { +			return fmt.Errorf("populateStatusFields: error populating status repliedTo: %s", err) +		}  	}  	return nil @@ -391,7 +396,6 @@ func (d *deref) populateStatusAttachments(ctx context.Context, status *gtsmodel.  	attachments := []*gtsmodel.MediaAttachment{}  	for _, a := range status.Attachments { -  		aURL, err := url.Parse(a.RemoteURL)  		if err != nil {  			l.Errorf("populateStatusAttachments: couldn't parse attachment url %s: %s", a.RemoteURL, err) @@ -401,6 +405,7 @@ func (d *deref) populateStatusAttachments(ctx context.Context, status *gtsmodel.  		attachment, err := d.GetRemoteAttachment(ctx, requestingUsername, aURL, status.AccountID, status.ID, a.File.ContentType)  		if err != nil {  			l.Errorf("populateStatusAttachments: couldn't get remote attachment %s: %s", a.RemoteURL, err) +			continue  		}  		attachmentIDs = append(attachmentIDs, attachment.ID) @@ -420,29 +425,16 @@ func (d *deref) populateStatusRepliedTo(ctx context.Context, status *gtsmodel.St  			return err  		} -		var replyToStatus *gtsmodel.Status -		errs := []string{} -  		// see if we have the status in our db already -		if s, err := d.db.GetStatusByURI(ctx, status.InReplyToURI); err != nil { -			errs = append(errs, err.Error()) -		} else { -			replyToStatus = s -		} - -		if replyToStatus == nil { -			// didn't find the status in our db, try to get it remotely -			if s, _, _, err := d.GetRemoteStatus(ctx, requestingUsername, statusURI, false); err != nil { -				errs = append(errs, err.Error()) -			} else { -				replyToStatus = s +		replyToStatus, err := d.db.GetStatusByURI(ctx, status.InReplyToURI) +		if err != nil { +			// Status was not in the DB, try fetch +			replyToStatus, _, _, err = d.GetRemoteStatus(ctx, requestingUsername, statusURI, false, false, false) +			if err != nil { +				return fmt.Errorf("populateStatusRepliedTo: couldn't get reply to status with uri %s: %s", status.InReplyToURI, err)  			}  		} -		if replyToStatus == nil { -			return fmt.Errorf("populateStatusRepliedTo: couldn't get reply to status with uri %s: %s", statusURI, strings.Join(errs, " : ")) -		} -  		// we have the status  		status.InReplyToID = replyToStatus.ID  		status.InReplyTo = replyToStatus diff --git a/internal/federation/dereferencing/status_test.go b/internal/federation/dereferencing/status_test.go index 2d259682b..43732ac77 100644 --- a/internal/federation/dereferencing/status_test.go +++ b/internal/federation/dereferencing/status_test.go @@ -119,7 +119,7 @@ func (suite *StatusTestSuite) TestDereferenceSimpleStatus() {  	fetchingAccount := suite.testAccounts["local_account_1"]  	statusURL := testrig.URLMustParse("https://unknown-instance.com/users/brand_new_person/statuses/01FE4NTHKWW7THT67EF10EB839") -	status, statusable, new, err := suite.dereferencer.GetRemoteStatus(context.Background(), fetchingAccount.Username, statusURL, false) +	status, statusable, new, err := suite.dereferencer.GetRemoteStatus(context.Background(), fetchingAccount.Username, statusURL, false, false, false)  	suite.NoError(err)  	suite.NotNil(status)  	suite.NotNil(statusable) @@ -157,7 +157,7 @@ func (suite *StatusTestSuite) TestDereferenceStatusWithMention() {  	fetchingAccount := suite.testAccounts["local_account_1"]  	statusURL := testrig.URLMustParse("https://unknown-instance.com/users/brand_new_person/statuses/01FE5Y30E3W4P7TRE0R98KAYQV") -	status, statusable, new, err := suite.dereferencer.GetRemoteStatus(context.Background(), fetchingAccount.Username, statusURL, false) +	status, statusable, new, err := suite.dereferencer.GetRemoteStatus(context.Background(), fetchingAccount.Username, statusURL, false, false, true)  	suite.NoError(err)  	suite.NotNil(status)  	suite.NotNil(statusable) diff --git a/internal/federation/dereferencing/thread.go b/internal/federation/dereferencing/thread.go index f9dd9aa09..af16c01b2 100644 --- a/internal/federation/dereferencing/thread.go +++ b/internal/federation/dereferencing/thread.go @@ -49,7 +49,7 @@ func (d *deref) DereferenceThread(ctx context.Context, username string, statusIR  	}  	// first make sure we have this status in our db -	_, statusable, _, err := d.GetRemoteStatus(ctx, username, statusIRI, true) +	_, statusable, _, err := d.GetRemoteStatus(ctx, username, statusIRI, true, false, false)  	if err != nil {  		return fmt.Errorf("DereferenceThread: error getting status with id %s: %s", statusIRI.String(), err)  	} @@ -104,7 +104,7 @@ func (d *deref) iterateAncestors(ctx context.Context, username string, statusIRI  	// If we reach here, we're looking at a remote status -- make sure we have it in our db by calling GetRemoteStatus  	// We call it with refresh to true because we want the statusable representation to parse inReplyTo from. -	status, statusable, _, err := d.GetRemoteStatus(ctx, username, &statusIRI, true) +	_, statusable, _, err := d.GetRemoteStatus(ctx, username, &statusIRI, true, false, false)  	if err != nil {  		l.Debugf("error getting remote status: %s", err)  		return nil @@ -116,18 +116,6 @@ func (d *deref) iterateAncestors(ctx context.Context, username string, statusIRI  		return nil  	} -	// get the ancestor status into our database if we don't have it yet -	if _, _, _, err := d.GetRemoteStatus(ctx, username, inReplyTo, false); err != nil { -		l.Debugf("error getting remote status: %s", err) -		return nil -	} - -	// now enrich the current status, since we should have the ancestor in the db -	if _, err := d.EnrichRemoteStatus(ctx, username, status); err != nil { -		l.Debugf("error enriching remote status: %s", err) -		return nil -	} -  	// now move up to the next ancestor  	return d.iterateAncestors(ctx, username, *inReplyTo)  } @@ -226,7 +214,7 @@ pageLoop:  			foundReplies = foundReplies + 1  			// get the remote statusable and put it in the db -			_, statusable, new, err := d.GetRemoteStatus(ctx, username, itemURI, false) +			_, statusable, new, err := d.GetRemoteStatus(ctx, username, itemURI, false, false, false)  			if new && err == nil && statusable != nil {  				// now iterate descendants of *that* status  				if err := d.iterateDescendants(ctx, username, *itemURI, statusable); err != nil { diff --git a/internal/federation/federator.go b/internal/federation/federator.go index 5eddcbb99..aecddf017 100644 --- a/internal/federation/federator.go +++ b/internal/federation/federator.go @@ -62,8 +62,8 @@ type Federator interface {  	GetRemoteAccount(ctx context.Context, username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error)  	EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error) -	GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) -	EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error) +	GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh, includeParent, includeChilds bool) (*gtsmodel.Status, ap.Statusable, bool, error) +	EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status, includeParent, includeChilds bool) (*gtsmodel.Status, error)  	GetRemoteInstance(ctx context.Context, username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error) @@ -88,7 +88,6 @@ type federator struct {  // NewFederator returns a new federator  func NewFederator(db db.DB, federatingDB federatingdb.DB, transportController transport.Controller, config *config.Config, log *logrus.Logger, typeConverter typeutils.TypeConverter, mediaHandler media.Handler) Federator { -  	dereferencer := dereferencing.NewDereferencer(config, db, typeConverter, transportController, mediaHandler, log)  	clock := &Clock{} diff --git a/internal/federation/sqlite-test.db b/internal/federation/sqlite-test.dbBinary files differ deleted file mode 100644 index d34adbfe9..000000000 --- a/internal/federation/sqlite-test.db +++ /dev/null diff --git a/internal/oauth/sqlite-test.db b/internal/oauth/sqlite-test.dbBinary files differ deleted file mode 100644 index 429e3d860..000000000 --- a/internal/oauth/sqlite-test.db +++ /dev/null diff --git a/internal/processing/fromfederator.go b/internal/processing/fromfederator.go index 2bb74db34..cb0999cf9 100644 --- a/internal/processing/fromfederator.go +++ b/internal/processing/fromfederator.go @@ -49,7 +49,7 @@ func (p *processor) processFromFederator(ctx context.Context, federatorMsg gtsmo  				return errors.New("note was not parseable as *gtsmodel.Status")  			} -			status, err := p.federator.EnrichRemoteStatus(ctx, federatorMsg.ReceivingAccount.Username, incomingStatus) +			status, err := p.federator.EnrichRemoteStatus(ctx, federatorMsg.ReceivingAccount.Username, incomingStatus, false, false)  			if err != nil {  				return err  			} diff --git a/internal/processing/search.go b/internal/processing/search.go index 768fceacd..85da0d83f 100644 --- a/internal/processing/search.go +++ b/internal/processing/search.go @@ -130,7 +130,7 @@ func (p *processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, u  	// we don't have it locally so dereference it if we're allowed to  	if resolve { -		status, _, _, err := p.federator.GetRemoteStatus(ctx, authed.Account.Username, uri, true) +		status, _, _, err := p.federator.GetRemoteStatus(ctx, authed.Account.Username, uri, true, false, false)  		if err == nil {  			if err := p.federator.DereferenceRemoteThread(ctx, authed.Account.Username, uri); err != nil {  				// try to deref the thread while we're here diff --git a/internal/processing/status/sqlite-test.db b/internal/processing/status/sqlite-test.dbBinary files differ deleted file mode 100644 index d266d6b1d..000000000 --- a/internal/processing/status/sqlite-test.db +++ /dev/null diff --git a/internal/text/sqlite-test.db b/internal/text/sqlite-test.dbBinary files differ deleted file mode 100644 index 08b0a8909..000000000 --- a/internal/text/sqlite-test.db +++ /dev/null diff --git a/internal/timeline/sqlite-test.db b/internal/timeline/sqlite-test.dbBinary files differ deleted file mode 100644 index 224027d43..000000000 --- a/internal/timeline/sqlite-test.db +++ /dev/null diff --git a/internal/typeutils/astointernal.go b/internal/typeutils/astointernal.go index 04d9cd824..4ba0df383 100644 --- a/internal/typeutils/astointernal.go +++ b/internal/typeutils/astointernal.go @@ -339,7 +339,6 @@ func (c *converter) ASStatusToStatus(ctx context.Context, statusable ap.Statusab  }  func (c *converter) ASFollowToFollowRequest(ctx context.Context, followable ap.Followable) (*gtsmodel.FollowRequest, error) { -  	idProp := followable.GetJSONLDId()  	if idProp == nil || !idProp.IsIRI() {  		return nil, errors.New("no id property set on follow, or was not an iri") diff --git a/internal/typeutils/sqlite-test.db b/internal/typeutils/sqlite-test.dbBinary files differ deleted file mode 100644 index 2775172f1..000000000 --- a/internal/typeutils/sqlite-test.db +++ /dev/null | 
