diff options
Diffstat (limited to 'internal/db')
| -rw-r--r-- | internal/db/bundb/account.go | 86 | ||||
| -rw-r--r-- | internal/db/bundb/bundb.go | 25 | ||||
| -rw-r--r-- | internal/db/bundb/conn.go | 20 | ||||
| -rw-r--r-- | internal/db/bundb/relationship.go | 4 | ||||
| -rw-r--r-- | internal/db/bundb/relationship_test.go | 124 | ||||
| -rw-r--r-- | internal/db/bundb/sqlite-test.db | bin | 315392 -> 0 bytes | |||
| -rw-r--r-- | internal/db/bundb/status.go | 145 | ||||
| -rw-r--r-- | internal/db/bundb/status_test.go | 5 | ||||
| -rw-r--r-- | internal/db/status.go | 6 | 
9 files changed, 288 insertions, 127 deletions
| diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index d7d45a739..32a70f7cd 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -25,6 +25,7 @@ import (  	"strings"  	"time" +	"github.com/superseriousbusiness/gotosocial/internal/cache"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -34,6 +35,7 @@ import (  type accountDB struct {  	config *config.Config  	conn   *DBConn +	cache  *cache.AccountCache  }  func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { @@ -45,60 +47,80 @@ func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {  }  func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { -	account := new(gtsmodel.Account) - -	q := a.newAccountQ(account). -		Where("account.id = ?", id) - -	err := q.Scan(ctx) -	if err != nil { -		return nil, a.conn.ProcessError(err) -	} -	return account, nil +	return a.getAccount( +		ctx, +		func() (*gtsmodel.Account, bool) { +			return a.cache.GetByID(id) +		}, +		func(account *gtsmodel.Account) error { +			return a.newAccountQ(account).Where("account.id = ?", id).Scan(ctx) +		}, +	)  }  func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { -	account := new(gtsmodel.Account) - -	q := a.newAccountQ(account). -		Where("account.uri = ?", uri) +	return a.getAccount( +		ctx, +		func() (*gtsmodel.Account, bool) { +			return a.cache.GetByURI(uri) +		}, +		func(account *gtsmodel.Account) error { +			return a.newAccountQ(account).Where("account.uri = ?", uri).Scan(ctx) +		}, +	) +} -	err := q.Scan(ctx) -	if err != nil { -		return nil, a.conn.ProcessError(err) -	} -	return account, nil +func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, db.Error) { +	return a.getAccount( +		ctx, +		func() (*gtsmodel.Account, bool) { +			return a.cache.GetByURL(url) +		}, +		func(account *gtsmodel.Account) error { +			return a.newAccountQ(account).Where("account.url = ?", url).Scan(ctx) +		}, +	)  } -func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { -	account := new(gtsmodel.Account) +func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) { +	// Attempt to fetch cached account +	account, cached := cacheGet() -	q := a.newAccountQ(account). -		Where("account.url = ?", uri) +	if !cached { +		account = >smodel.Account{} -	err := q.Scan(ctx) -	if err != nil { -		return nil, a.conn.ProcessError(err) +		// Not cached! Perform database query +		err := dbQuery(account) +		if err != nil { +			return nil, a.conn.ProcessError(err) +		} + +		// Place in the cache +		a.cache.Put(account)  	} +  	return account, nil  }  func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {  	if strings.TrimSpace(account.ID) == "" { +		// TODO: we should not need this check here  		return nil, errors.New("account had no ID")  	} +	// Update the account's last-used  	account.UpdatedAt = time.Now() -	q := a.conn. -		NewUpdate(). -		Model(account). -		WherePK() - -	_, err := q.Exec(ctx) +	// Update the account model in the DB +	_, err := a.conn.NewUpdate().Model(account).WherePK().Exec(ctx)  	if err != nil {  		return nil, a.conn.ProcessError(err)  	} + +	// Place updated account in cache +	// (this will replace existing, i.e. invalidating) +	a.cache.Put(account) +  	return account, nil  } diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 248232fe3..6fcc56e51 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -91,6 +91,15 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)  		conn = WrapDBConn(bun.NewDB(sqldb, pgdialect.New()), log)  	case dbTypeSqlite:  		// SQLITE + +		// Drop anything fancy from DB address +		c.DBConfig.Address = strings.Split(c.DBConfig.Address, "?")[0] +		c.DBConfig.Address = strings.TrimPrefix(c.DBConfig.Address, "file:") + +		// Append our own SQLite preferences +		c.DBConfig.Address = "file:" + c.DBConfig.Address + "?cache=shared" + +		// Open new DB instance  		var err error  		sqldb, err = sql.Open("sqlite", c.DBConfig.Address)  		if err != nil { @@ -98,7 +107,7 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)  		}  		conn = WrapDBConn(bun.NewDB(sqldb, sqlitedialect.New()), log) -		if strings.HasPrefix(strings.TrimPrefix(c.DBConfig.Address, "file:"), ":memory:") { +		if c.DBConfig.Address == "file::memory:?cache=shared" {  			log.Warn("sqlite in-memory database should only be used for debugging")  			// don't close connections on disconnect -- otherwise @@ -121,11 +130,10 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)  		conn.RegisterModel(t)  	} +	accounts := &accountDB{config: c, conn: conn, cache: cache.NewAccountCache()} +  	ps := &bunDBService{ -		Account: &accountDB{ -			config: c, -			conn:   conn, -		}, +		Account: accounts,  		Admin: &adminDB{  			config: c,  			conn:   conn, @@ -165,9 +173,10 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)  			conn:   conn,  		},  		Status: &statusDB{ -			config: c, -			conn:   conn, -			cache:  cache.NewStatusCache(), +			config:   c, +			conn:     conn, +			cache:    cache.NewStatusCache(), +			accounts: accounts,  		},  		Timeline: &timelineDB{  			config: c, diff --git a/internal/db/bundb/conn.go b/internal/db/bundb/conn.go index 698adff3d..abaebcebd 100644 --- a/internal/db/bundb/conn.go +++ b/internal/db/bundb/conn.go @@ -12,6 +12,8 @@ import (  // dbConn wrapps a bun.DB conn to provide SQL-type specific additional functionality  type DBConn struct { +	// TODO: move *Config here, no need to be in each struct type +  	errProc func(error) db.Error // errProc is the SQL-type specific error processor  	log     *logrus.Logger       // log is the logger passed with this DBConn  	*bun.DB                      // DB is the underlying bun.DB connection @@ -35,6 +37,24 @@ func WrapDBConn(dbConn *bun.DB, log *logrus.Logger) *DBConn {  	}  } +func (conn *DBConn) RunInTx(ctx context.Context, fn func(bun.Tx) error) db.Error { +	// Acquire a new transaction +	tx, err := conn.BeginTx(ctx, nil) +	if err != nil { +		return conn.ProcessError(err) +	} + +	// Perform supplied transaction +	if err = fn(tx); err != nil { +		tx.Rollback() //nolint +		return conn.ProcessError(err) +	} + +	// Finally, commit transaction +	err = tx.Commit() +	return conn.ProcessError(err) +} +  // ProcessError processes an error to replace any known values with our own db.Error types,  // making it easier to catch specific situations (e.g. no rows, already exists, etc)  func (conn *DBConn) ProcessError(err error) db.Error { diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index 56b752593..64d896527 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -237,7 +237,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI  	if _, err := r.conn.  		NewInsert().  		Model(follow). -		On("CONFLICT ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI). +		On("CONFLICT (account_id,target_account_id) DO UPDATE set uri = ?", follow.URI).  		Exec(ctx); err != nil {  		return nil, r.conn.ProcessError(err)  	} @@ -298,7 +298,7 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str  	if localOnly {  		q = q.ColumnExpr("follow.*"). -			Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)"). +			Join("JOIN accounts AS a ON follow.account_id = CAST(a.id as TEXT)").  			Where("follow.target_account_id = ?", accountID).  			WhereGroup(" AND ", whereEmptyOrNull("a.domain"))  	} else { diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go new file mode 100644 index 000000000..dcc71b37c --- /dev/null +++ b/internal/db/bundb/relationship_test.go @@ -0,0 +1,124 @@ +/* +   GoToSocial +   Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + +   This program is free software: you can redistribute it and/or modify +   it under the terms of the GNU Affero General Public License as published by +   the Free Software Foundation, either version 3 of the License, or +   (at your option) any later version. + +   This program is distributed in the hope that it will be useful, +   but WITHOUT ANY WARRANTY; without even the implied warranty of +   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +   GNU Affero General Public License for more details. + +   You should have received a copy of the GNU Affero General Public License +   along with this program.  If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb_test + +import ( +	"context" +	"errors" +	"testing" + +	"github.com/stretchr/testify/suite" +	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/testrig" +) + +type RelationshipTestSuite struct { +	BunDBStandardTestSuite +} + +func (suite *RelationshipTestSuite) SetupSuite() { +	suite.testTokens = testrig.NewTestTokens() +	suite.testClients = testrig.NewTestClients() +	suite.testApplications = testrig.NewTestApplications() +	suite.testUsers = testrig.NewTestUsers() +	suite.testAccounts = testrig.NewTestAccounts() +	suite.testAttachments = testrig.NewTestAttachments() +	suite.testStatuses = testrig.NewTestStatuses() +	suite.testTags = testrig.NewTestTags() +	suite.testMentions = testrig.NewTestMentions() +} + +func (suite *RelationshipTestSuite) SetupTest() { +	suite.config = testrig.NewTestConfig() +	suite.db = testrig.NewTestDB() +	suite.log = testrig.NewTestLog() + +	testrig.StandardDBSetup(suite.db, suite.testAccounts) +} + +func (suite *RelationshipTestSuite) TearDownTest() { +	testrig.StandardDBTeardown(suite.db) +} + +func (suite *RelationshipTestSuite) TestIsBlocked() { +	suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) TestGetBlock() { +	suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) TestGetRelationship() { +	suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) TestIsFollowing() { +	suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) TestIsMutualFollowing() { +	suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) AcceptFollowRequest() { +	for _, account := range suite.testAccounts { +		_, err := suite.db.AcceptFollowRequest(context.Background(), account.ID, "NON-EXISTENT-ID") +		if err != nil && !errors.Is(err, db.ErrNoEntries) { +			suite.Suite.Fail("error accepting follow request: %v", err) +		} +	} +} + +func (suite *RelationshipTestSuite) GetAccountFollowRequests() { +	suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) GetAccountFollows() { +	suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) CountAccountFollows() { +	suite.Suite.T().Skip("TODO: implement") +} + +func (suite *RelationshipTestSuite) GetAccountFollowedBy() { +	// TODO: more comprehensive tests here + +	for _, account := range suite.testAccounts { +		var err error + +		_, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, false) +		if err != nil { +			suite.Suite.Fail("error checking accounts followed by: %v", err) +		} + +		_, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, true) +		if err != nil { +			suite.Suite.Fail("error checking localOnly accounts followed by: %v", err) +		} +	} +} + +func (suite *RelationshipTestSuite) CountAccountFollowedBy() { +	suite.Suite.T().Skip("TODO: implement") +} + +func TestRelationshipTestSuite(t *testing.T) { +	suite.Run(t, new(RelationshipTestSuite)) +} diff --git a/internal/db/bundb/sqlite-test.db b/internal/db/bundb/sqlite-test.dbBinary files differ deleted file mode 100644 index ed3b25ee3..000000000 --- a/internal/db/bundb/sqlite-test.db +++ /dev/null diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index 1d5acf0fc..9464cfadf 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -21,7 +21,6 @@ package bundb  import (  	"container/list"  	"context" -	"errors"  	"time"  	"github.com/superseriousbusiness/gotosocial/internal/cache" @@ -35,6 +34,11 @@ type statusDB struct {  	config *config.Config  	conn   *DBConn  	cache  *cache.StatusCache + +	// TODO: keep method definitions in same place but instead have receiver +	//       all point to one single "db" type, so they can all share methods +	//       and caches where necessary +	accounts *accountDB  }  func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { @@ -51,30 +55,6 @@ func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {  		Relation("CreatedWithApplication")  } -func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status { -	if status.InReplyToID != "" && status.InReplyTo == nil { -		// TODO: do we want to keep this possibly recursive strategy? - -		if inReplyTo, cached := s.cache.GetByID(status.InReplyToID); cached { -			status.InReplyTo = inReplyTo -		} else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil { -			status.InReplyTo = inReplyTo -		} -	} - -	if status.BoostOfID != "" && status.BoostOf == nil { -		// TODO: do we want to keep this possibly recursive strategy? - -		if boostOf, cached := s.cache.GetByID(status.BoostOfID); cached { -			status.BoostOf = boostOf -		} else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil { -			status.BoostOf = boostOf -		} -	} - -	return status -} -  func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {  	return s.conn.  		NewSelect(). @@ -85,64 +65,79 @@ func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {  }  func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) { -	if status, cached := s.cache.GetByID(id); cached { -		return status, nil -	} - -	status := >smodel.Status{} - -	q := s.newStatusQ(status). -		Where("status.id = ?", id) - -	err := q.Scan(ctx) -	if err != nil { -		return nil, s.conn.ProcessError(err) -	} - -	s.cache.Put(status) -	return s.getAttachedStatuses(ctx, status), nil +	return s.getStatus( +		ctx, +		func() (*gtsmodel.Status, bool) { +			return s.cache.GetByID(id) +		}, +		func(status *gtsmodel.Status) error { +			return s.newStatusQ(status).Where("status.id = ?", id).Scan(ctx) +		}, +	)  }  func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) { -	if status, cached := s.cache.GetByURI(uri); cached { -		return status, nil -	} - -	status := >smodel.Status{} +	return s.getStatus( +		ctx, +		func() (*gtsmodel.Status, bool) { +			return s.cache.GetByURI(uri) +		}, +		func(status *gtsmodel.Status) error { +			return s.newStatusQ(status).Where("LOWER(status.uri) = LOWER(?)", uri).Scan(ctx) +		}, +	) +} -	q := s.newStatusQ(status). -		Where("LOWER(status.uri) = LOWER(?)", uri) +func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) { +	return s.getStatus( +		ctx, +		func() (*gtsmodel.Status, bool) { +			return s.cache.GetByURL(url) +		}, +		func(status *gtsmodel.Status) error { +			return s.newStatusQ(status).Where("LOWER(status.url) = LOWER(?)", url).Scan(ctx) +		}, +	) +} -	err := q.Scan(ctx) -	if err != nil { -		return nil, s.conn.ProcessError(err) -	} +func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Status, bool), dbQuery func(*gtsmodel.Status) error) (*gtsmodel.Status, db.Error) { +	// Attempt to fetch cached status +	status, cached := cacheGet() -	s.cache.Put(status) -	return s.getAttachedStatuses(ctx, status), nil -} +	if !cached { +		status = >smodel.Status{} -func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) { -	if status, cached := s.cache.GetByURL(url); cached { -		return status, nil -	} +		// Not cached! Perform database query +		err := dbQuery(status) +		if err != nil { +			return nil, s.conn.ProcessError(err) +		} -	status := >smodel.Status{} +		// If there is boosted, fetch from DB also +		if status.BoostOfID != "" { +			boostOf, err := s.GetStatusByID(ctx, status.BoostOfID) +			if err == nil { +				status.BoostOf = boostOf +			} +		} -	q := s.newStatusQ(status). -		Where("LOWER(status.url) = LOWER(?)", url) +		// Place in the cache +		s.cache.Put(status) +	} -	err := q.Scan(ctx) +	// Set the status author account +	author, err := s.accounts.GetAccountByID(ctx, status.AccountID)  	if err != nil { -		return nil, s.conn.ProcessError(err) +		return nil, err  	} -	s.cache.Put(status) -	return s.getAttachedStatuses(ctx, status), nil +	// Return the prepared status +	status.Account = author +	return status, nil  }  func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error { -	transaction := func(ctx context.Context, tx bun.Tx) error { +	return s.conn.RunInTx(ctx, func(tx bun.Tx) error {  		// create links between this status and any emojis it uses  		for _, i := range status.EmojiIDs {  			if _, err := tx.NewInsert().Model(>smodel.StatusToEmoji{ @@ -174,10 +169,10 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er  			}  		} +		// Finally, insert the status  		_, err := tx.NewInsert().Model(status).Exec(ctx)  		return err -	} -	return s.conn.ProcessError(s.conn.RunInTx(ctx, nil, transaction)) +	})  }  func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { @@ -210,12 +205,8 @@ func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Statu  	children := []*gtsmodel.Status{}  	for e := foundStatuses.Front(); e != nil; e = e.Next() { -		entry, ok := e.Value.(*gtsmodel.Status) -		if !ok { -			panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) -		} -  		// only append children, not the overall parent status +		entry := e.Value.(*gtsmodel.Status)  		if entry.ID != status.ID {  			children = append(children, entry)  		} @@ -242,11 +233,7 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status,  	for _, child := range immediateChildren {  	insertLoop:  		for e := foundStatuses.Front(); e != nil; e = e.Next() { -			entry, ok := e.Value.(*gtsmodel.Status) -			if !ok { -				panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) -			} - +			entry := e.Value.(*gtsmodel.Status)  			if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {  				foundStatuses.InsertAfter(child, e)  				break insertLoop diff --git a/internal/db/bundb/status_test.go b/internal/db/bundb/status_test.go index 4f846441b..7acc86ff9 100644 --- a/internal/db/bundb/status_test.go +++ b/internal/db/bundb/status_test.go @@ -105,10 +105,9 @@ func (suite *StatusTestSuite) TestGetStatusWithMention() {  	suite.NotNil(status)  	suite.NotNil(status.Account)  	suite.NotNil(status.CreatedWithApplication) -	suite.NotEmpty(status.Mentions)  	suite.NotEmpty(status.MentionIDs) -	suite.NotNil(status.InReplyTo) -	suite.NotNil(status.InReplyToAccount) +	suite.NotEmpty(status.InReplyToID) +	suite.NotEmpty(status.InReplyToAccountID)  }  func (suite *StatusTestSuite) TestGetStatusTwice() { diff --git a/internal/db/status.go b/internal/db/status.go index 7430433c4..f26f8942e 100644 --- a/internal/db/status.go +++ b/internal/db/status.go @@ -26,13 +26,13 @@ import (  // Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses.  type Status interface { -	// GetStatusByID returns one status from the database, with all rel fields populated (if possible). +	// GetStatusByID returns one status from the database, with no rel fields populated, only their linking ID / URIs  	GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, Error) -	// GetStatusByURI returns one status from the database, with all rel fields populated (if possible). +	// GetStatusByURI returns one status from the database, with no rel fields populated, only their linking ID / URIs  	GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, Error) -	// GetStatusByURL returns one status from the database, with all rel fields populated (if possible). +	// GetStatusByURL returns one status from the database, with no rel fields populated, only their linking ID / URIs  	GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error)  	// PutStatus stores one status in the database. | 
