diff options
author | 2021-08-25 15:34:33 +0200 | |
---|---|---|
committer | 2021-08-25 15:34:33 +0200 | |
commit | 2dc9fc1626507bb54417fc4a1920b847cafb27a2 (patch) | |
tree | 4ddeac479b923db38090aac8bd9209f3646851c1 /internal/db | |
parent | Manually approves followers (#146) (diff) | |
download | gotosocial-2dc9fc1626507bb54417fc4a1920b847cafb27a2.tar.xz |
Pg to bun (#148)
* start moving to bun
* changing more stuff
* more
* and yet more
* tests passing
* seems stable now
* more big changes
* small fix
* little fixes
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/account.go | 26 | ||||
-rw-r--r-- | internal/db/admin.go | 11 | ||||
-rw-r--r-- | internal/db/basic.go | 31 | ||||
-rw-r--r-- | internal/db/bundb/account.go (renamed from internal/db/pg/account.go) | 155 | ||||
-rw-r--r-- | internal/db/bundb/account_test.go (renamed from internal/db/pg/account_test.go) | 22 | ||||
-rw-r--r-- | internal/db/bundb/admin.go (renamed from internal/db/pg/admin.go) | 149 | ||||
-rw-r--r-- | internal/db/bundb/basic.go | 179 | ||||
-rw-r--r-- | internal/db/bundb/basic_test.go | 68 | ||||
-rw-r--r-- | internal/db/bundb/bundb.go (renamed from internal/db/pg/pg.go) | 150 | ||||
-rw-r--r-- | internal/db/bundb/bundb_test.go (renamed from internal/db/pg/pg_test.go) | 4 | ||||
-rw-r--r-- | internal/db/bundb/domain.go (renamed from internal/db/pg/domain.go) | 30 | ||||
-rw-r--r-- | internal/db/bundb/instance.go (renamed from internal/db/pg/instance.go) | 66 | ||||
-rw-r--r-- | internal/db/bundb/media.go (renamed from internal/db/pg/media.go) | 18 | ||||
-rw-r--r-- | internal/db/bundb/mention.go (renamed from internal/db/pg/mention.go) | 22 | ||||
-rw-r--r-- | internal/db/bundb/notification.go (renamed from internal/db/pg/notification.go) | 25 | ||||
-rw-r--r-- | internal/db/bundb/relationship.go (renamed from internal/db/pg/relationship.go) | 172 | ||||
-rw-r--r-- | internal/db/bundb/session.go | 85 | ||||
-rw-r--r-- | internal/db/bundb/status.go | 375 | ||||
-rw-r--r-- | internal/db/bundb/status_test.go (renamed from internal/db/pg/status_test.go) | 22 | ||||
-rw-r--r-- | internal/db/bundb/timeline.go (renamed from internal/db/pg/timeline.go) | 89 | ||||
-rw-r--r-- | internal/db/bundb/util.go | 78 | ||||
-rw-r--r-- | internal/db/db.go | 9 | ||||
-rw-r--r-- | internal/db/domain.go | 13 | ||||
-rw-r--r-- | internal/db/instance.go | 14 | ||||
-rw-r--r-- | internal/db/media.go | 8 | ||||
-rw-r--r-- | internal/db/mention.go | 10 | ||||
-rw-r--r-- | internal/db/notification.go | 10 | ||||
-rw-r--r-- | internal/db/pg/basic.go | 205 | ||||
-rw-r--r-- | internal/db/pg/status.go | 318 | ||||
-rw-r--r-- | internal/db/pg/util.go | 25 | ||||
-rw-r--r-- | internal/db/relationship.go | 30 | ||||
-rw-r--r-- | internal/db/session.go | 31 | ||||
-rw-r--r-- | internal/db/status.go | 36 | ||||
-rw-r--r-- | internal/db/timeline.go | 12 |
34 files changed, 1461 insertions, 1037 deletions
diff --git a/internal/db/account.go b/internal/db/account.go index 0e1575f9b..058a89859 100644 --- a/internal/db/account.go +++ b/internal/db/account.go @@ -19,6 +19,7 @@ package db import ( + "context" "time" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -27,40 +28,43 @@ import ( // Account contains functions related to account getting/setting/creation. type Account interface { // GetAccountByID returns one account with the given ID, or an error if something goes wrong. - GetAccountByID(id string) (*gtsmodel.Account, Error) + GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, Error) // GetAccountByURI returns one account with the given URI, or an error if something goes wrong. - GetAccountByURI(uri string) (*gtsmodel.Account, Error) + GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, Error) // GetAccountByURL returns one account with the given URL, or an error if something goes wrong. - GetAccountByURL(uri string) (*gtsmodel.Account, Error) + GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error) + + // UpdateAccount updates one account by ID. + UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error) // GetLocalAccountByUsername returns an account on this instance by its username. - GetLocalAccountByUsername(username string) (*gtsmodel.Account, Error) + GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, Error) // GetAccountFaves fetches faves/likes created by the target accountID. - GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, Error) + GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, Error) // GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID. - CountAccountStatuses(accountID string) (int, Error) + CountAccountStatuses(ctx context.Context, accountID string) (int, Error) // GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided // then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can // be very memory intensive so you probably shouldn't do this! // In case of no entries, a 'no entries' error will be returned - GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, Error) + GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, Error) - GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error) + GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error) // GetAccountLastPosted simply gets the timestamp of the most recent post by the account. // // The returned time will be zero if account has never posted anything. - GetAccountLastPosted(accountID string) (time.Time, Error) + GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, Error) // SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment. - SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error + SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error // GetInstanceAccount returns the instance account for the given domain. // If domain is empty, this instance account will be returned. - GetInstanceAccount(domain string) (*gtsmodel.Account, Error) + GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, Error) } diff --git a/internal/db/admin.go b/internal/db/admin.go index aa2b22f47..24d628e84 100644 --- a/internal/db/admin.go +++ b/internal/db/admin.go @@ -19,6 +19,7 @@ package db import ( + "context" "net" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -28,26 +29,26 @@ import ( type Admin interface { // IsUsernameAvailable checks whether a given username is available on our domain. // Returns an error if the username is already taken, or something went wrong in the db. - IsUsernameAvailable(username string) Error + IsUsernameAvailable(ctx context.Context, username string) (bool, Error) // IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain. // Return an error if: // A) the email is already associated with an account // B) we block signups from this email domain // C) something went wrong in the db - IsEmailAvailable(email string) Error + IsEmailAvailable(ctx context.Context, email string) (bool, Error) // NewSignup creates a new user in the database with the given parameters. // By the time this function is called, it should be assumed that all the parameters have passed validation! - NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, Error) + NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, Error) // CreateInstanceAccount creates an account in the database with the same username as the instance host value. // Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'. // This is needed for things like serving files that belong to the instance and not an individual user/account. - CreateInstanceAccount() Error + CreateInstanceAccount(ctx context.Context) Error // CreateInstanceInstance creates an instance in the database with the same domain as the instance host value. // Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'. // This is needed for things like serving instance information through /api/v1/instance - CreateInstanceInstance() Error + CreateInstanceInstance(ctx context.Context) Error } diff --git a/internal/db/basic.go b/internal/db/basic.go index 729920bba..cf65ddc09 100644 --- a/internal/db/basic.go +++ b/internal/db/basic.go @@ -24,15 +24,11 @@ import "context" type Basic interface { // CreateTable creates a table for the given interface. // For implementations that don't use tables, this can just return nil. - CreateTable(i interface{}) Error + CreateTable(ctx context.Context, i interface{}) Error // DropTable drops the table for the given interface. // For implementations that don't use tables, this can just return nil. - DropTable(i interface{}) Error - - // RegisterTable registers a table for use in many2many relations. - // For implementations that don't use tables, or many2many relations, this can just return nil. - RegisterTable(i interface{}) Error + DropTable(ctx context.Context, i interface{}) Error // Stop should stop and close the database connection cleanly, returning an error if this is not possible. // If the database implementation doesn't need to be stopped, this can just return nil. @@ -45,43 +41,38 @@ type Basic interface { // for other implementations (for example, in-memory) it might just be the key of a map. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // In case of no entries, a 'no entries' error will be returned - GetByID(id string, i interface{}) Error + GetByID(ctx context.Context, id string, i interface{}) Error // GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the // name of the key to select from. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // In case of no entries, a 'no entries' error will be returned - GetWhere(where []Where, i interface{}) Error + GetWhere(ctx context.Context, where []Where, i interface{}) Error // GetAll will try to get all entries of type i. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // In case of no entries, a 'no entries' error will be returned - GetAll(i interface{}) Error + GetAll(ctx context.Context, i interface{}) Error // Put simply stores i. It is up to the implementation to figure out how to store it, and using what key. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. - Put(i interface{}) Error - - // Upsert stores or updates i based on the given conflict column, as in https://www.postgresqltutorial.com/postgresql-upsert/ - // It is up to the implementation to figure out how to store it, and using what key. - // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. - Upsert(i interface{}, conflictColumn string) Error + Put(ctx context.Context, i interface{}) Error // UpdateByID updates i with id id. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. - UpdateByID(id string, i interface{}) Error + UpdateByID(ctx context.Context, id string, i interface{}) Error // UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value. - UpdateOneByID(id string, key string, value interface{}, i interface{}) Error + UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) Error // UpdateWhere updates column key of interface i with the given value, where the given parameters apply. - UpdateWhere(where []Where, key string, value interface{}, i interface{}) Error + UpdateWhere(ctx context.Context, where []Where, key string, value interface{}, i interface{}) Error // DeleteByID removes i with id id. // If i didn't exist anyway, then no error should be returned. - DeleteByID(id string, i interface{}) Error + DeleteByID(ctx context.Context, id string, i interface{}) Error // DeleteWhere deletes i where key = value // If i didn't exist anyway, then no error should be returned. - DeleteWhere(where []Where, i interface{}) Error + DeleteWhere(ctx context.Context, where []Where, i interface{}) Error } diff --git a/internal/db/pg/account.go b/internal/db/bundb/account.go index 3889c6601..7ebb79a15 100644 --- a/internal/db/pg/account.go +++ b/internal/db/bundb/account.go @@ -16,70 +16,90 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg +package bundb import ( "context" "errors" "fmt" + "strings" "time" - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type accountDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } -func (a *accountDB) newAccountQ(account *gtsmodel.Account) *orm.Query { - return a.conn.Model(account). +func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { + return a.conn. + NewSelect(). + Model(account). Relation("AvatarMediaAttachment"). Relation("HeaderMediaAttachment") } -func (a *accountDB) GetAccountByID(id string) (*gtsmodel.Account, db.Error) { - account := >smodel.Account{} +func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { + account := new(gtsmodel.Account) q := a.newAccountQ(account). Where("account.id = ?", id) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return account, err } -func (a *accountDB) GetAccountByURI(uri string) (*gtsmodel.Account, db.Error) { - account := >smodel.Account{} +func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { + account := new(gtsmodel.Account) q := a.newAccountQ(account). Where("account.uri = ?", uri) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return account, err } -func (a *accountDB) GetAccountByURL(uri string) (*gtsmodel.Account, db.Error) { - account := >smodel.Account{} +func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { + account := new(gtsmodel.Account) q := a.newAccountQ(account). Where("account.url = ?", uri) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return account, err } -func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.Error) { - account := >smodel.Account{} +func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { + if strings.TrimSpace(account.ID) == "" { + return nil, errors.New("account had no ID") + } + + account.UpdatedAt = time.Now() + + q := a.conn. + NewUpdate(). + Model(account). + WherePK() + + _, err := q.Exec(ctx) + + err = processErrorResponse(err) + + return account, err +} + +func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { + account := new(gtsmodel.Account) q := a.newAccountQ(account) @@ -90,29 +110,31 @@ func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.Err } else { q = q. Where("account.username = ?", domain). - Where("? IS NULL", pg.Ident("domain")) + Where("? IS NULL", bun.Ident("domain")) } - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return account, err } -func (a *accountDB) GetAccountLastPosted(accountID string) (time.Time, db.Error) { - status := >smodel.Status{} +func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, db.Error) { + status := new(gtsmodel.Status) - q := a.conn.Model(status). + q := a.conn. + NewSelect(). + Model(status). Order("id DESC"). Limit(1). Where("account_id = ?", accountID). Column("created_at") - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return status.CreatedAt, err } -func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error { +func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error { if mediaAttachment.Avatar && mediaAttachment.Header { return errors.New("one media attachment cannot be both header and avatar") } @@ -127,51 +149,66 @@ func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAtta } // TODO: there are probably more side effects here that need to be handled - if _, err := a.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil { + if _, err := a.conn. + NewInsert(). + Model(mediaAttachment). + Exec(ctx); err != nil { return err } - if _, err := a.conn.Model(>smodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil { + if _, err := a.conn. + NewUpdate(). + Model(>smodel.Account{}). + Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID). + Where("id = ?", accountID). + Exec(ctx); err != nil { return err } return nil } -func (a *accountDB) GetLocalAccountByUsername(username string) (*gtsmodel.Account, db.Error) { - account := >smodel.Account{} +func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, db.Error) { + account := new(gtsmodel.Account) q := a.newAccountQ(account). Where("username = ?", username). - Where("? IS NULL", pg.Ident("domain")) + Where("? IS NULL", bun.Ident("domain")) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return account, err } -func (a *accountDB) GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, db.Error) { - faves := []*gtsmodel.StatusFave{} +func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) { + faves := new([]*gtsmodel.StatusFave) - if err := a.conn.Model(&faves). + if err := a.conn. + NewSelect(). + Model(faves). Where("account_id = ?", accountID). - Select(); err != nil { - if err == pg.ErrNoRows { - return faves, nil - } + Scan(ctx); err != nil { return nil, err } - return faves, nil + return *faves, nil } -func (a *accountDB) CountAccountStatuses(accountID string) (int, db.Error) { - return a.conn.Model(>smodel.Status{}).Where("account_id = ?", accountID).Count() +func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) { + return a.conn. + NewSelect(). + Model(>smodel.Status{}). + Where("account_id = ?", accountID). + Count(ctx) } -func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) { +func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) { a.log.Debugf("getting statuses for account %s", accountID) statuses := []*gtsmodel.Status{} - q := a.conn.Model(&statuses).Order("id DESC") + q := a.conn. + NewSelect(). + Model(&statuses). + Order("id DESC") + if accountID != "" { q = q.Where("account_id = ?", accountID) } @@ -181,27 +218,26 @@ func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeRepli } if excludeReplies { - q = q.Where("? IS NULL", pg.Ident("in_reply_to_id")) + q = q.Where("? IS NULL", bun.Ident("in_reply_to_id")) } if pinnedOnly { q = q.Where("pinned = ?", true) } - if mediaOnly { - q = q.WhereGroup(func(q *pg.Query) (*pg.Query, error) { - return q.Where("? IS NOT NULL", pg.Ident("attachments")).Where("attachments != '{}'"), nil - }) - } - if maxID != "" { q = q.Where("id < ?", maxID) } - if err := q.Select(); err != nil { - if err == pg.ErrNoRows { - return nil, db.ErrNoEntries - } + if mediaOnly { + q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery { + return q. + WhereOr("? IS NOT NULL", bun.Ident("attachments")). + WhereOr("attachments != '{}'") + }) + } + + if err := q.Scan(ctx); err != nil { return nil, err } @@ -213,10 +249,12 @@ func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeRepli return statuses, nil } -func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) { +func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) { blocks := []*gtsmodel.Block{} - fq := a.conn.Model(&blocks). + fq := a.conn. + NewSelect(). + Model(&blocks). Where("block.account_id = ?", accountID). Relation("TargetAccount"). Order("block.id DESC") @@ -233,11 +271,8 @@ func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID str fq = fq.Limit(limit) } - err := fq.Select() + err := fq.Scan(ctx) if err != nil { - if err == pg.ErrNoRows { - return nil, "", "", db.ErrNoEntries - } return nil, "", "", err } diff --git a/internal/db/pg/account_test.go b/internal/db/bundb/account_test.go index 7ea5ff39a..7174b781d 100644 --- a/internal/db/pg/account_test.go +++ b/internal/db/bundb/account_test.go @@ -16,17 +16,19 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg_test +package bundb_test import ( + "context" "testing" + "time" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/testrig" ) type AccountTestSuite struct { - PGStandardTestSuite + BunDBStandardTestSuite } func (suite *AccountTestSuite) SetupSuite() { @@ -54,7 +56,7 @@ func (suite *AccountTestSuite) TearDownTest() { } func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() { - account, err := suite.db.GetAccountByID(suite.testAccounts["local_account_1"].ID) + account, err := suite.db.GetAccountByID(context.Background(), suite.testAccounts["local_account_1"].ID) if err != nil { suite.FailNow(err.Error()) } @@ -65,6 +67,20 @@ func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() { suite.NotEmpty(account.HeaderMediaAttachment.URL) } +func (suite *AccountTestSuite) TestUpdateAccount() { + testAccount := suite.testAccounts["local_account_1"] + + testAccount.DisplayName = "new display name!" + + _, err := suite.db.UpdateAccount(context.Background(), testAccount) + suite.NoError(err) + + updated, err := suite.db.GetAccountByID(context.Background(), testAccount.ID) + suite.NoError(err) + suite.Equal("new display name!", updated.DisplayName) + suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second) +} + func TestAccountTestSuite(t *testing.T) { suite.Run(t, new(AccountTestSuite)) } diff --git a/internal/db/pg/admin.go b/internal/db/bundb/admin.go index 854f56ef0..67a1e8a0d 100644 --- a/internal/db/pg/admin.go +++ b/internal/db/bundb/admin.go @@ -16,76 +16,76 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg +package bundb import ( "context" "crypto/rand" "crypto/rsa" + "database/sql" "fmt" "net" "net/mail" "strings" "time" - "github.com/go-pg/pg/v10" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/uptrace/bun" "golang.org/x/crypto/bcrypt" ) type adminDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } -func (a *adminDB) IsUsernameAvailable(username string) db.Error { - // if no error we fail because it means we found something - // if error but it's not pg.ErrNoRows then we fail - // if err is pg.ErrNoRows we're good, we found nothing so continue - if err := a.conn.Model(>smodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil { - return fmt.Errorf("username %s already in use", username) - } else if err != pg.ErrNoRows { - return fmt.Errorf("db error: %s", err) - } - return nil +func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) { + q := a.conn. + NewSelect(). + Model(>smodel.Account{}). + Where("username = ?", username). + Where("domain = ?", nil) + + return notExists(ctx, q) } -func (a *adminDB) IsEmailAvailable(email string) db.Error { +func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.Error) { // parse the domain from the email m, err := mail.ParseAddress(email) if err != nil { - return fmt.Errorf("error parsing email address %s: %s", email, err) + return false, fmt.Errorf("error parsing email address %s: %s", email, err) } domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @ // check if the email domain is blocked - if err := a.conn.Model(>smodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil { + if err := a.conn. + NewSelect(). + Model(>smodel.EmailDomainBlock{}). + Where("domain = ?", domain). + Scan(ctx); err == nil { // fail because we found something - return fmt.Errorf("email domain %s is blocked", domain) - } else if err != pg.ErrNoRows { - // fail because we got an unexpected error - return fmt.Errorf("db error: %s", err) + return false, fmt.Errorf("email domain %s is blocked", domain) + } else if err != sql.ErrNoRows { + return false, processErrorResponse(err) } // check if this email is associated with a user already - if err := a.conn.Model(>smodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil { - // fail because we found something - return fmt.Errorf("email %s already in use", email) - } else if err != pg.ErrNoRows { - // fail because we got an unexpected error - return fmt.Errorf("db error: %s", err) - } - return nil + q := a.conn. + NewSelect(). + Model(>smodel.User{}). + Where("email = ?", email). + WhereOr("unconfirmed_email = ?", email) + + return notExists(ctx, q) } -func (a *adminDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) { +func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) { key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { a.log.Errorf("error creating new rsa key: %s", err) @@ -94,13 +94,12 @@ func (a *adminDB) NewSignup(username string, reason string, requireApproval bool // if something went wrong while creating a user, we might already have an account, so check here first... acct := >smodel.Account{} - err = a.conn.Model(acct).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select() + err = a.conn.NewSelect(). + Model(acct). + Where("username = ?", username). + Where("? IS NULL", bun.Ident("domain")). + Scan(ctx) if err != nil { - // there's been an actual error - if err != pg.ErrNoRows { - return nil, fmt.Errorf("db error checking existence of account: %s", err) - } - // we just don't have an account yet create one newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host) newAccountID, err := id.NewRandomULID() @@ -125,7 +124,10 @@ func (a *adminDB) NewSignup(username string, reason string, requireApproval bool FollowingURI: newAccountURIs.FollowingURI, FeaturedCollectionURI: newAccountURIs.CollectionURI, } - if _, err = a.conn.Model(acct).Insert(); err != nil { + if _, err = a.conn. + NewInsert(). + Model(acct). + Exec(ctx); err != nil { return nil, err } } @@ -161,15 +163,33 @@ func (a *adminDB) NewSignup(username string, reason string, requireApproval bool u.Moderator = true } - if _, err = a.conn.Model(u).Insert(); err != nil { + if _, err = a.conn. + NewInsert(). + Model(u). + Exec(ctx); err != nil { return nil, err } return u, nil } -func (a *adminDB) CreateInstanceAccount() db.Error { +func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { username := a.config.Host + + // check if instance account already exists + existsQ := a.conn. + NewSelect(). + Model(>smodel.Account{}). + Where("username = ?", username). + Where("? IS NULL", bun.Ident("domain")) + count, err := existsQ.Count(ctx) + if err != nil && count == 1 { + a.log.Infof("instance account %s already exists", username) + return nil + } else if err != sql.ErrNoRows { + return processErrorResponse(err) + } + key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { a.log.Errorf("error creating new rsa key: %s", err) @@ -198,19 +218,36 @@ func (a *adminDB) CreateInstanceAccount() db.Error { FollowingURI: newAccountURIs.FollowingURI, FeaturedCollectionURI: newAccountURIs.CollectionURI, } - inserted, err := a.conn.Model(acct).Where("username = ?", username).SelectOrInsert() - if err != nil { + + insertQ := a.conn. + NewInsert(). + Model(acct) + + if _, err := insertQ.Exec(ctx); err != nil { return err } - if inserted { - a.log.Infof("created instance account %s with id %s", username, acct.ID) - } else { - a.log.Infof("instance account %s already exists with id %s", username, acct.ID) - } + + a.log.Infof("instance account %s CREATED with id %s", username, acct.ID) return nil } -func (a *adminDB) CreateInstanceInstance() db.Error { +func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error { + domain := a.config.Host + + // check if instance entry already exists + existsQ := a.conn. + NewSelect(). + Model(>smodel.Instance{}). + Where("domain = ?", domain) + + count, err := existsQ.Count(ctx) + if err != nil && count == 1 { + a.log.Infof("instance instance %s already exists", domain) + return nil + } else if err != sql.ErrNoRows { + return processErrorResponse(err) + } + iID, err := id.NewRandomULID() if err != nil { return err @@ -218,18 +255,18 @@ func (a *adminDB) CreateInstanceInstance() db.Error { i := >smodel.Instance{ ID: iID, - Domain: a.config.Host, - Title: a.config.Host, + Domain: domain, + Title: domain, URI: fmt.Sprintf("%s://%s", a.config.Protocol, a.config.Host), } - inserted, err := a.conn.Model(i).Where("domain = ?", a.config.Host).SelectOrInsert() - if err != nil { + + insertQ := a.conn. + NewInsert(). + Model(i) + + if _, err := insertQ.Exec(ctx); err != nil { return err } - if inserted { - a.log.Infof("created instance instance %s with id %s", a.config.Host, i.ID) - } else { - a.log.Infof("instance instance %s already exists with id %s", a.config.Host, i.ID) - } + a.log.Infof("created instance instance %s with id %s", domain, i.ID) return nil } diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go new file mode 100644 index 000000000..983b6b810 --- /dev/null +++ b/internal/db/bundb/basic.go @@ -0,0 +1,179 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + "errors" + "strings" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/uptrace/bun" +) + +type basicDB struct { + config *config.Config + conn *bun.DB + log *logrus.Logger +} + +func (b *basicDB) Put(ctx context.Context, i interface{}) db.Error { + _, err := b.conn.NewInsert().Model(i).Exec(ctx) + if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") { + return db.ErrAlreadyExists + } + return err +} + +func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Error { + q := b.conn. + NewSelect(). + Model(i). + Where("id = ?", id) + + return processErrorResponse(q.Scan(ctx)) +} + +func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error { + if len(where) == 0 { + return errors.New("no queries provided") + } + + q := b.conn.NewSelect().Model(i) + for _, w := range where { + + if w.Value == nil { + q = q.Where("? IS NULL", bun.Ident(w.Key)) + } else { + if w.CaseInsensitive { + q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value) + } else { + q = q.Where("? = ?", bun.Safe(w.Key), w.Value) + } + } + } + + return processErrorResponse(q.Scan(ctx)) +} + +func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error { + q := b.conn. + NewSelect(). + Model(i) + + return processErrorResponse(q.Scan(ctx)) +} + +func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error { + q := b.conn. + NewDelete(). + Model(i). + Where("id = ?", id) + + _, err := q.Exec(ctx) + + return processErrorResponse(err) +} + +func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error { + if len(where) == 0 { + return errors.New("no queries provided") + } + + q := b.conn. + NewDelete(). + Model(i) + + for _, w := range where { + q = q.Where("? = ?", bun.Safe(w.Key), w.Value) + } + + _, err := q.Exec(ctx) + + return processErrorResponse(err) +} + +func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.Error { + q := b.conn. + NewUpdate(). + Model(i). + WherePK() + + _, err := q.Exec(ctx) + + return processErrorResponse(err) +} + +func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) db.Error { + q := b.conn.NewUpdate(). + Model(i). + Set("? = ?", bun.Safe(key), value). + WherePK() + + _, err := q.Exec(ctx) + + return processErrorResponse(err) +} + +func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error { + q := b.conn.NewUpdate().Model(i) + + for _, w := range where { + if w.Value == nil { + q = q.Where("? IS NULL", bun.Ident(w.Key)) + } else { + if w.CaseInsensitive { + q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value) + } else { + q = q.Where("? = ?", bun.Safe(w.Key), w.Value) + } + } + } + + q = q.Set("? = ?", bun.Safe(key), value) + + _, err := q.Exec(ctx) + + return processErrorResponse(err) +} + +func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error { + _, err := b.conn.NewCreateTable().Model(i).IfNotExists().Exec(ctx) + return err +} + +func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error { + _, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx) + return processErrorResponse(err) +} + +func (b *basicDB) IsHealthy(ctx context.Context) db.Error { + return b.conn.Ping() +} + +func (b *basicDB) Stop(ctx context.Context) db.Error { + b.log.Info("closing db connection") + if err := b.conn.Close(); err != nil { + // only cancel if there's a problem closing the db + return err + } + return nil +} diff --git a/internal/db/bundb/basic_test.go b/internal/db/bundb/basic_test.go new file mode 100644 index 000000000..9189618c9 --- /dev/null +++ b/internal/db/bundb/basic_test.go @@ -0,0 +1,68 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type BasicTestSuite struct { + BunDBStandardTestSuite +} + +func (suite *BasicTestSuite) SetupSuite() { + suite.testTokens = testrig.NewTestTokens() + suite.testClients = testrig.NewTestClients() + suite.testApplications = testrig.NewTestApplications() + suite.testUsers = testrig.NewTestUsers() + suite.testAccounts = testrig.NewTestAccounts() + suite.testAttachments = testrig.NewTestAttachments() + suite.testStatuses = testrig.NewTestStatuses() + suite.testTags = testrig.NewTestTags() + suite.testMentions = testrig.NewTestMentions() +} + +func (suite *BasicTestSuite) SetupTest() { + suite.config = testrig.NewTestConfig() + suite.db = testrig.NewTestDB() + suite.log = testrig.NewTestLog() + + testrig.StandardDBSetup(suite.db, suite.testAccounts) +} + +func (suite *BasicTestSuite) TearDownTest() { + testrig.StandardDBTeardown(suite.db) +} + +func (suite *BasicTestSuite) TestGetAccountByID() { + testAccount := suite.testAccounts["local_account_1"] + + a := >smodel.Account{} + err := suite.db.GetByID(context.Background(), testAccount.ID, a) + suite.NoError(err) +} + +func TestBasicTestSuite(t *testing.T) { + suite.Run(t, new(BasicTestSuite)) +} diff --git a/internal/db/pg/pg.go b/internal/db/bundb/bundb.go index 0437baf02..49ed09cbd 100644 --- a/internal/db/pg/pg.go +++ b/internal/db/bundb/bundb.go @@ -16,12 +16,13 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg +package bundb import ( "context" "crypto/tls" "crypto/x509" + "database/sql" "encoding/pem" "errors" "fmt" @@ -29,14 +30,20 @@ import ( "strings" "time" - "github.com/go-pg/pg/extra/pgdebug" - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/stdlib" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/pgdialect" +) + +const ( + dbTypePostgres = "postgres" + dbTypeSqlite = "sqlite" ) var registerTables []interface{} = []interface{}{ @@ -44,8 +51,8 @@ var registerTables []interface{} = []interface{}{ >smodel.StatusToTag{}, } -// postgresService satisfies the DB interface -type postgresService struct { +// bunDBService satisfies the DB interface +type bunDBService struct { db.Account db.Admin db.Basic @@ -55,130 +62,115 @@ type postgresService struct { db.Mention db.Notification db.Relationship + db.Session db.Status db.Timeline config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } -// NewPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface. -// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection. -func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) { - for _, t := range registerTables { - // https://pg.uptrace.dev/orm/many-to-many-relation/ - orm.RegisterTable(t) - } - - opts, err := derivePGOptions(c) - if err != nil { - return nil, fmt.Errorf("could not create postgres service: %s", err) - } - log.Debugf("using pg options: %+v", opts) - - // create a connection - pgCtx, cancel := context.WithCancel(ctx) - conn := pg.Connect(opts).WithContext(pgCtx) - - // this will break the logfmt format we normally log in, - // since we can't choose where pg outputs to and it defaults to - // stdout. So use this option with care! - if log.GetLevel() >= logrus.TraceLevel { - conn.AddQueryHook(pgdebug.DebugHook{ - // Print all queries. - Verbose: true, - }) +// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface. +// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection. +func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) { + var sqldb *sql.DB + var conn *bun.DB + + // depending on the database type we're trying to create, we need to use a different driver... + switch strings.ToLower(c.DBConfig.Type) { + case dbTypePostgres: + // POSTGRES + opts, err := deriveBunDBPGOptions(c) + if err != nil { + return nil, fmt.Errorf("could not create bundb postgres options: %s", err) + } + sqldb = stdlib.OpenDB(*opts) + conn = bun.NewDB(sqldb, pgdialect.New()) + case dbTypeSqlite: + // SQLITE + // TODO: https://bun.uptrace.dev/guide/drivers.html#sqlite + default: + return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type)) } // actually *begin* the connection so that we can tell if the db is there and listening - if err := conn.Ping(ctx); err != nil { - cancel() + if err := conn.Ping(); err != nil { return nil, fmt.Errorf("db connection error: %s", err) } + log.Info("connected to database") - // print out discovered postgres version - var version string - if _, err = conn.QueryOneContext(ctx, pg.Scan(&version), "SELECT version()"); err != nil { - cancel() - return nil, fmt.Errorf("db connection error: %s", err) + for _, t := range registerTables { + // https://bun.uptrace.dev/orm/many-to-many-relation/ + conn.RegisterModel(t) } - log.Infof("connected to postgres version: %s", version) - ps := &postgresService{ + ps := &bunDBService{ Account: &accountDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Admin: &adminDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Basic: &basicDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Domain: &domainDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Instance: &instanceDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Media: &mediaDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Mention: &mentionDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Notification: ¬ificationDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Relationship: &relationshipDB{ config: c, conn: conn, log: log, - cancel: cancel, + }, + Session: &sessionDB{ + config: c, + conn: conn, + log: log, }, Status: &statusDB{ config: c, conn: conn, log: log, - cancel: cancel, }, Timeline: &timelineDB{ config: c, conn: conn, log: log, - cancel: cancel, }, config: c, conn: conn, log: log, - cancel: cancel, } - // we can confidently return this useable postgres service now + // we can confidently return this useable service now return ps, nil } @@ -186,9 +178,9 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge HANDY STUFF */ -// derivePGOptions takes an application config and returns either a ready-to-use *pg.Options +// deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options // with sensible defaults, or an error if it's not satisfied by the provided config. -func derivePGOptions(c *config.Config) (*pg.Options, error) { +func deriveBunDBPGOptions(c *config.Config) (*pgx.ConnConfig, error) { if strings.ToUpper(c.DBConfig.Type) != db.DBTypePostgres { return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, c.DBConfig.Type) } @@ -266,18 +258,16 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) { tlsConfig.RootCAs = certPool } - // We can rely on the pg library we're using to set - // sensible defaults for everything we don't set here. - options := &pg.Options{ - Addr: fmt.Sprintf("%s:%d", c.DBConfig.Address, c.DBConfig.Port), - User: c.DBConfig.User, - Password: c.DBConfig.Password, - Database: c.DBConfig.Database, - ApplicationName: c.ApplicationName, - TLSConfig: tlsConfig, - } + cfg, _ := pgx.ParseConfig("") + cfg.Host = c.DBConfig.Address + cfg.Port = uint16(c.DBConfig.Port) + cfg.User = c.DBConfig.User + cfg.Password = c.DBConfig.Password + cfg.TLSConfig = tlsConfig + cfg.Database = c.DBConfig.Database + cfg.PreferSimpleProtocol = true - return options, nil + return cfg, nil } /* @@ -286,9 +276,9 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) { // TODO: move these to the type converter, it's bananas that they're here and not there -func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) { +func (ps *bunDBService) MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) { ogAccount := >smodel.Account{} - if err := ps.conn.Model(ogAccount).Where("id = ?", originAccountID).Select(); err != nil { + if err := ps.conn.NewSelect().Model(ogAccount).Where("id = ?", originAccountID).Scan(ctx); err != nil { return nil, err } @@ -333,14 +323,14 @@ func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, ori // match username + account, case insensitive if local { // local user -- should have a null domain - err = ps.conn.Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", pg.Ident("username"), username).Where("? IS NULL", pg.Ident("domain")).Select() + err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("? IS NULL", bun.Ident("domain")).Scan(ctx) } else { // remote user -- should have domain defined - err = ps.conn.Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", pg.Ident("username"), username).Where("LOWER(?) = LOWER(?)", pg.Ident("domain"), domain).Select() + err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("LOWER(?) = LOWER(?)", bun.Ident("domain"), domain).Scan(ctx) } if err != nil { - if err == pg.ErrNoRows { + if err == sql.ErrNoRows { // no result found for this username/domain so just don't include it as a mencho and carry on about our business ps.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain) continue @@ -364,14 +354,14 @@ func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, ori return menchies, nil } -func (ps *postgresService) TagStringsToTags(tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) { +func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) { newTags := []*gtsmodel.Tag{} for _, t := range tags { tag := >smodel.Tag{} // we can use selectorinsert here to create the new tag if it doesn't exist already // inserted will be true if this is a new tag we just created - if err := ps.conn.Model(tag).Where("LOWER(?) = LOWER(?)", pg.Ident("name"), t).Select(); err != nil { - if err == pg.ErrNoRows { + if err := ps.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil { + if err == sql.ErrNoRows { // tag doesn't exist yet so populate it newID, err := id.NewRandomULID() if err != nil { @@ -400,13 +390,13 @@ func (ps *postgresService) TagStringsToTags(tags []string, originAccountID strin return newTags, nil } -func (ps *postgresService) EmojiStringsToEmojis(emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) { +func (ps *bunDBService) EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) { newEmojis := []*gtsmodel.Emoji{} for _, e := range emojis { emoji := >smodel.Emoji{} - err := ps.conn.Model(emoji).Where("shortcode = ?", e).Where("visible_in_picker = true").Where("disabled = false").Select() + err := ps.conn.NewSelect().Model(emoji).Where("shortcode = ?", e).Where("visible_in_picker = true").Where("disabled = false").Scan(ctx) if err != nil { - if err == pg.ErrNoRows { + if err == sql.ErrNoRows { // no result found for this username/domain so just don't include it as an emoji and carry on about our business ps.log.Debugf("no emoji found with shortcode %s, skipping it", e) continue diff --git a/internal/db/pg/pg_test.go b/internal/db/bundb/bundb_test.go index c1e10abdf..b789375af 100644 --- a/internal/db/pg/pg_test.go +++ b/internal/db/bundb/bundb_test.go @@ -16,7 +16,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg_test +package bundb_test import ( "github.com/sirupsen/logrus" @@ -27,7 +27,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -type PGStandardTestSuite struct { +type BunDBStandardTestSuite struct { // standard suite interfaces suite.Suite config *config.Config diff --git a/internal/db/pg/domain.go b/internal/db/bundb/domain.go index 4e9b2ab48..6aa2b8ffe 100644 --- a/internal/db/pg/domain.go +++ b/internal/db/bundb/domain.go @@ -16,48 +16,46 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg +package bundb import ( "context" "net/url" - "github.com/go-pg/pg/v10" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/uptrace/bun" ) type domainDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } -func (d *domainDB) IsDomainBlocked(domain string) (bool, db.Error) { +func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) { if domain == "" { return false, nil } - blocked, err := d.conn. + q := d.conn. + NewSelect(). Model(>smodel.DomainBlock{}). Where("LOWER(domain) = LOWER(?)", domain). - Exists() + Limit(1) - err = processErrorResponse(err) - - return blocked, err + return exists(ctx, q) } -func (d *domainDB) AreDomainsBlocked(domains []string) (bool, db.Error) { +func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) { // filter out any doubles uniqueDomains := util.UniqueStrings(domains) for _, domain := range uniqueDomains { - if blocked, err := d.IsDomainBlocked(domain); err != nil { + if blocked, err := d.IsDomainBlocked(ctx, domain); err != nil { return false, err } else if blocked { return blocked, nil @@ -68,16 +66,16 @@ func (d *domainDB) AreDomainsBlocked(domains []string) (bool, db.Error) { return false, nil } -func (d *domainDB) IsURIBlocked(uri *url.URL) (bool, db.Error) { +func (d *domainDB) IsURIBlocked(ctx context.Context, uri *url.URL) (bool, db.Error) { domain := uri.Hostname() - return d.IsDomainBlocked(domain) + return d.IsDomainBlocked(ctx, domain) } -func (d *domainDB) AreURIsBlocked(uris []*url.URL) (bool, db.Error) { +func (d *domainDB) AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, db.Error) { domains := []string{} for _, uri := range uris { domains = append(domains, uri.Hostname()) } - return d.AreDomainsBlocked(domains) + return d.AreDomainsBlocked(ctx, domains) } diff --git a/internal/db/pg/instance.go b/internal/db/bundb/instance.go index 968832ca5..f9364346e 100644 --- a/internal/db/pg/instance.go +++ b/internal/db/bundb/instance.go @@ -16,43 +16,50 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg +package bundb import ( "context" - "github.com/go-pg/pg/v10" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type instanceDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } -func (i *instanceDB) CountInstanceUsers(domain string) (int, db.Error) { - q := i.conn.Model(&[]*gtsmodel.Account{}) +func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) { + q := i.conn. + NewSelect(). + Model(&[]*gtsmodel.Account{}) if domain == i.config.Host { // if the domain is *this* domain, just count where the domain field is null - q = q.Where("? IS NULL", pg.Ident("domain")) + q = q.Where("? IS NULL", bun.Ident("domain")) } else { q = q.Where("domain = ?", domain) } // don't count the instance account or suspended users - q = q.Where("username != ?", domain).Where("? IS NULL", pg.Ident("suspended_at")) + q = q. + Where("username != ?", domain). + Where("? IS NULL", bun.Ident("suspended_at")) - return q.Count() + count, err := q.Count(ctx) + + return count, processErrorResponse(err) } -func (i *instanceDB) CountInstanceStatuses(domain string) (int, db.Error) { - q := i.conn.Model(&[]*gtsmodel.Status{}) +func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) { + q := i.conn. + NewSelect(). + Model(&[]*gtsmodel.Status{}) if domain == i.config.Host { // if the domain is *this* domain, just count where local is true @@ -63,30 +70,39 @@ func (i *instanceDB) CountInstanceStatuses(domain string) (int, db.Error) { Where("account.domain = ?", domain) } - return q.Count() + count, err := q.Count(ctx) + + return count, processErrorResponse(err) } -func (i *instanceDB) CountInstanceDomains(domain string) (int, db.Error) { - q := i.conn.Model(&[]*gtsmodel.Instance{}) +func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) { + q := i.conn. + NewSelect(). + Model(&[]*gtsmodel.Instance{}) if domain == i.config.Host { // if the domain is *this* domain, just count other instances it knows about // exclude domains that are blocked - q = q.Where("domain != ?", domain).Where("? IS NULL", pg.Ident("suspended_at")) + q = q.Where("domain != ?", domain).Where("? IS NULL", bun.Ident("suspended_at")) } else { // TODO: implement federated domain counting properly for remote domains return 0, nil } - return q.Count() + count, err := q.Count(ctx) + + return count, processErrorResponse(err) } -func (i *instanceDB) GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) { +func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) { i.log.Debug("GetAccountsForInstance") accounts := []*gtsmodel.Account{} - q := i.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC") + q := i.conn.NewSelect(). + Model(&accounts). + Where("domain = ?", domain). + Order("id DESC") if maxID != "" { q = q.Where("id < ?", maxID) @@ -96,17 +112,7 @@ func (i *instanceDB) GetInstanceAccounts(domain string, maxID string, limit int) q = q.Limit(limit) } - err := q.Select() - if err != nil { - if err == pg.ErrNoRows { - return nil, db.ErrNoEntries - } - return nil, err - } - - if len(accounts) == 0 { - return nil, db.ErrNoEntries - } + err := processErrorResponse(q.Scan(ctx)) - return accounts, nil + return accounts, err } diff --git a/internal/db/pg/media.go b/internal/db/bundb/media.go index 618030af3..04e55ca62 100644 --- a/internal/db/pg/media.go +++ b/internal/db/bundb/media.go @@ -16,38 +16,38 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg +package bundb import ( "context" - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type mediaDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } -func (m *mediaDB) newMediaQ(i interface{}) *orm.Query { - return m.conn.Model(i). +func (m *mediaDB) newMediaQ(i interface{}) *bun.SelectQuery { + return m.conn. + NewSelect(). + Model(i). Relation("Account") } -func (m *mediaDB) GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, db.Error) { +func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, db.Error) { attachment := >smodel.MediaAttachment{} q := m.newMediaQ(attachment). Where("media_attachment.id = ?", id) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return attachment, err } diff --git a/internal/db/pg/mention.go b/internal/db/bundb/mention.go index b31f07b67..a444f9b5f 100644 --- a/internal/db/pg/mention.go +++ b/internal/db/bundb/mention.go @@ -16,25 +16,23 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg +package bundb import ( "context" - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type mentionDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger - cancel context.CancelFunc cache cache.Cache } @@ -67,14 +65,16 @@ func (m *mentionDB) mentionCached(id string) (*gtsmodel.Mention, bool) { return mention, true } -func (m *mentionDB) newMentionQ(i interface{}) *orm.Query { - return m.conn.Model(i). +func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery { + return m.conn. + NewSelect(). + Model(i). Relation("Status"). Relation("OriginAccount"). Relation("TargetAccount") } -func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) { +func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) { if mention, cached := m.mentionCached(id); cached { return mention, nil } @@ -84,7 +84,7 @@ func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) { q := m.newMentionQ(mention). Where("mention.id = ?", id) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) if err == nil && mention != nil { m.cacheMention(id, mention) @@ -93,11 +93,11 @@ func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) { return mention, err } -func (m *mentionDB) GetMentions(ids []string) ([]*gtsmodel.Mention, db.Error) { +func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) { mentions := []*gtsmodel.Mention{} for _, i := range ids { - mention, err := m.GetMention(i) + mention, err := m.GetMention(ctx, i) if err != nil { return nil, processErrorResponse(err) } diff --git a/internal/db/pg/notification.go b/internal/db/bundb/notification.go index 281a76d85..1c30837ec 100644 --- a/internal/db/pg/notification.go +++ b/internal/db/bundb/notification.go @@ -16,25 +16,23 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg +package bundb import ( "context" - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type notificationDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger - cancel context.CancelFunc cache cache.Cache } @@ -67,14 +65,16 @@ func (n *notificationDB) notificationCached(id string) (*gtsmodel.Notification, return notification, true } -func (n *notificationDB) newNotificationQ(i interface{}) *orm.Query { - return n.conn.Model(i). +func (n *notificationDB) newNotificationQ(i interface{}) *bun.SelectQuery { + return n.conn. + NewSelect(). + Model(i). Relation("OriginAccount"). Relation("TargetAccount"). Relation("Status") } -func (n *notificationDB) GetNotification(id string) (*gtsmodel.Notification, db.Error) { +func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) { if notification, cached := n.notificationCached(id); cached { return notification, nil } @@ -84,7 +84,7 @@ func (n *notificationDB) GetNotification(id string) (*gtsmodel.Notification, db. q := n.newNotificationQ(notification). Where("notification.id = ?", id) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) if err == nil && notification != nil { n.cacheNotification(id, notification) @@ -93,10 +93,11 @@ func (n *notificationDB) GetNotification(id string) (*gtsmodel.Notification, db. return notification, err } -func (n *notificationDB) GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) { +func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) { // begin by selecting just the IDs notifIDs := []*gtsmodel.Notification{} q := n.conn. + NewSelect(). Model(¬ifIDs). Column("id"). Where("target_account_id = ?", accountID). @@ -114,7 +115,7 @@ func (n *notificationDB) GetNotifications(accountID string, limit int, maxID str q = q.Limit(limit) } - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) if err != nil { return nil, err } @@ -123,7 +124,7 @@ func (n *notificationDB) GetNotifications(accountID string, limit int, maxID str // reason for this is that for each notif, we can instead get it from our cache if it's cached notifications := []*gtsmodel.Notification{} for _, notifID := range notifIDs { - notif, err := n.GetNotification(notifID.ID) + notif, err := n.GetNotification(ctx, notifID.ID) errP := processErrorResponse(err) if errP != nil { return nil, errP diff --git a/internal/db/pg/relationship.go b/internal/db/bundb/relationship.go index 76bd50c76..ccc604baf 100644 --- a/internal/db/pg/relationship.go +++ b/internal/db/bundb/relationship.go @@ -16,44 +16,49 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg +package bundb import ( "context" + "database/sql" "fmt" - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type relationshipDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } -func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *orm.Query { - return r.conn.Model(block). +func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *bun.SelectQuery { + return r.conn. + NewSelect(). + Model(block). Relation("Account"). Relation("TargetAccount") } -func (r *relationshipDB) newFollowQ(follow interface{}) *orm.Query { - return r.conn.Model(follow). +func (r *relationshipDB) newFollowQ(follow interface{}) *bun.SelectQuery { + return r.conn. + NewSelect(). + Model(follow). Relation("Account"). Relation("TargetAccount") } -func (r *relationshipDB) IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, db.Error) { +func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) { q := r.conn. + NewSelect(). Model(>smodel.Block{}). Where("account_id = ?", account1). - Where("target_account_id = ?", account2) + Where("target_account_id = ?", account2). + Limit(1) if eitherDirection { q = q. @@ -61,30 +66,36 @@ func (r *relationshipDB) IsBlocked(account1 string, account2 string, eitherDirec Where("account_id = ?", account2) } - return q.Exists() + return exists(ctx, q) } -func (r *relationshipDB) GetBlock(account1 string, account2 string) (*gtsmodel.Block, db.Error) { +func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) { block := >smodel.Block{} q := r.newBlockQ(block). Where("block.account_id = ?", account1). Where("block.target_account_id = ?", account2) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return block, err } -func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) { +func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) { rel := >smodel.Relationship{ ID: targetAccount, } // check if the requesting account follows the target account follow := >smodel.Follow{} - if err := r.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil { - if err != pg.ErrNoRows { + if err := r.conn. + NewSelect(). + Model(follow). + Where("account_id = ?", requestingAccount). + Where("target_account_id = ?", targetAccount). + Limit(1). + Scan(ctx); err != nil { + if err != sql.ErrNoRows { // a proper error return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err) } @@ -100,75 +111,101 @@ func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount } // check if the target account follows the requesting account - followedBy, err := r.conn.Model(>smodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists() + count, err := r.conn. + NewSelect(). + Model(>smodel.Follow{}). + Where("account_id = ?", targetAccount). + Where("target_account_id = ?", requestingAccount). + Limit(1). + Count(ctx) if err != nil { return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err) } - rel.FollowedBy = followedBy + rel.FollowedBy = count > 0 // check if the requesting account blocks the target account - blocking, err := r.conn.Model(>smodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists() + count, err = r.conn.NewSelect(). + Model(>smodel.Block{}). + Where("account_id = ?", requestingAccount). + Where("target_account_id = ?", targetAccount). + Limit(1). + Count(ctx) if err != nil { return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err) } - rel.Blocking = blocking + rel.Blocking = count > 0 // check if the target account blocks the requesting account - blockedBy, err := r.conn.Model(>smodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists() + count, err = r.conn. + NewSelect(). + Model(>smodel.Block{}). + Where("account_id = ?", targetAccount). + Where("target_account_id = ?", requestingAccount). + Limit(1). + Count(ctx) if err != nil { return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) } - rel.BlockedBy = blockedBy + rel.BlockedBy = count > 0 // check if there's a pending following request from requesting account to target account - requested, err := r.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists() + count, err = r.conn. + NewSelect(). + Model(>smodel.FollowRequest{}). + Where("account_id = ?", requestingAccount). + Where("target_account_id = ?", targetAccount). + Limit(1). + Count(ctx) if err != nil { return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) } - rel.Requested = requested + rel.Requested = count > 0 return rel, nil } -func (r *relationshipDB) IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) { +func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) { if sourceAccount == nil || targetAccount == nil { return false, nil } q := r.conn. + NewSelect(). Model(>smodel.Follow{}). Where("account_id = ?", sourceAccount.ID). - Where("target_account_id = ?", targetAccount.ID) + Where("target_account_id = ?", targetAccount.ID). + Limit(1) - return q.Exists() + return exists(ctx, q) } -func (r *relationshipDB) IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) { +func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) { if sourceAccount == nil || targetAccount == nil { return false, nil } q := r.conn. + NewSelect(). Model(>smodel.FollowRequest{}). Where("account_id = ?", sourceAccount.ID). Where("target_account_id = ?", targetAccount.ID) - return q.Exists() + return exists(ctx, q) } -func (r *relationshipDB) IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) { +func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) { if account1 == nil || account2 == nil { return false, nil } // make sure account 1 follows account 2 - f1, err := r.IsFollowing(account1, account2) + f1, err := r.IsFollowing(ctx, account1, account2) if err != nil { return false, processErrorResponse(err) } // make sure account 2 follows account 1 - f2, err := r.IsFollowing(account2, account1) + f2, err := r.IsFollowing(ctx, account2, account1) if err != nil { return false, processErrorResponse(err) } @@ -176,14 +213,16 @@ func (r *relationshipDB) IsMutualFollowing(account1 *gtsmodel.Account, account2 return f1 && f2, nil } -func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) { +func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) { // make sure the original follow request exists fr := >smodel.FollowRequest{} - if err := r.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil { - if err == pg.ErrMultiRows { - return nil, db.ErrNoEntries - } - return nil, err + if err := r.conn. + NewSelect(). + Model(fr). + Where("account_id = ?", originAccountID). + Where("target_account_id = ?", targetAccountID). + Scan(ctx); err != nil { + return nil, processErrorResponse(err) } // create a new follow to 'replace' the request with @@ -195,82 +234,95 @@ func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccou } // if the follow already exists, just update the URI -- we don't need to do anything else - if _, err := r.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil { - return nil, err + if _, err := r.conn. + NewInsert(). + Model(follow). + On("CONFLICT CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI). + Exec(ctx); err != nil { + return nil, processErrorResponse(err) } // now remove the follow request - if _, err := r.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil { - return nil, err + if _, err := r.conn. + NewDelete(). + Model(>smodel.FollowRequest{}). + Where("account_id = ?", originAccountID). + Where("target_account_id = ?", targetAccountID). + Exec(ctx); err != nil { + return nil, processErrorResponse(err) } return follow, nil } -func (r *relationshipDB) GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, db.Error) { +func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) { followRequests := []*gtsmodel.FollowRequest{} q := r.newFollowQ(&followRequests). Where("target_account_id = ?", accountID) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return followRequests, err } -func (r *relationshipDB) GetAccountFollows(accountID string) ([]*gtsmodel.Follow, db.Error) { +func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, db.Error) { follows := []*gtsmodel.Follow{} q := r.newFollowQ(&follows). Where("account_id = ?", accountID) - err := processErrorResponse(q.Select()) + err := processErrorResponse(q.Scan(ctx)) return follows, err } -func (r *relationshipDB) CountAccountFollows(accountID string, localOnly bool) (int, db.Error) { +func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { return r.conn. + NewSelect(). Model(&[]*gtsmodel.Follow{}). Where("account_id = ?", accountID). - Count() + Count(ctx) } -func (r *relationshipDB) GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) { +func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) { follows := []*gtsmodel.Follow{} - q := r.conn.Model(&follows) + q := r.conn. + NewSelect(). + Model(&follows) if localOnly { // for local accounts let's get where domain is null OR where domain is an empty string, just to be safe - whereGroup := func(q *pg.Query) (*pg.Query, error) { + whereGroup := func(q *bun.SelectQuery) *bun.SelectQuery { q = q. - WhereOr("? IS NULL", pg.Ident("a.domain")). + WhereOr("? IS NULL", bun.Ident("a.domain")). WhereOr("a.domain = ?", "") - return q, nil + return q } q = q.ColumnExpr("follow.*"). Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)"). Where("follow.target_account_id = ?", accountID). - WhereGroup(whereGroup) + WhereGroup(" AND ", whereGroup) } else { q = q.Where("target_account_id = ?", accountID) } - if err := q.Select(); err != nil { - if err == pg.ErrNoRows { + if err := q.Scan(ctx); err != nil { + if err == sql.ErrNoRows { return follows, nil } - return nil, err + return nil, processErrorResponse(err) } return follows, nil } -func (r *relationshipDB) CountAccountFollowedBy(accountID string, localOnly bool) (int, db.Error) { +func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { return r.conn. + NewSelect(). Model(&[]*gtsmodel.Follow{}). Where("target_account_id = ?", accountID). - Count() + Count(ctx) } diff --git a/internal/db/bundb/session.go b/internal/db/bundb/session.go new file mode 100644 index 000000000..87e20673d --- /dev/null +++ b/internal/db/bundb/session.go @@ -0,0 +1,85 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + "crypto/rand" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/uptrace/bun" +) + +type sessionDB struct { + config *config.Config + conn *bun.DB + log *logrus.Logger +} + +func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) { + rs := new(gtsmodel.RouterSession) + + q := s.conn. + NewSelect(). + Model(rs). + Limit(1) + + _, err := q.Exec(ctx) + + err = processErrorResponse(err) + + return rs, err +} + +func (s *sessionDB) CreateSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) { + auth := make([]byte, 32) + crypt := make([]byte, 32) + + if _, err := rand.Read(auth); err != nil { + return nil, err + } + if _, err := rand.Read(crypt); err != nil { + return nil, err + } + + rid, err := id.NewULID() + if err != nil { + return nil, err + } + + rs := >smodel.RouterSession{ + ID: rid, + Auth: auth, + Crypt: crypt, + } + + q := s.conn. + NewInsert(). + Model(rs) + + _, err = q.Exec(ctx) + + err = processErrorResponse(err) + + return rs, err +} diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go new file mode 100644 index 000000000..da8d8ca41 --- /dev/null +++ b/internal/db/bundb/status.go @@ -0,0 +1,375 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "container/list" + "context" + "errors" + "time" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/cache" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +type statusDB struct { + config *config.Config + conn *bun.DB + log *logrus.Logger + cache cache.Cache +} + +func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) { + if s.cache == nil { + s.cache = cache.New() + } + + if err := s.cache.Store(id, status); err != nil { + s.log.Panicf("statusDB: error storing in cache: %s", err) + } +} + +func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) { + if s.cache == nil { + s.cache = cache.New() + return nil, false + } + + sI, err := s.cache.Fetch(id) + if err != nil || sI == nil { + return nil, false + } + + status, ok := sI.(*gtsmodel.Status) + if !ok { + s.log.Panicf("statusDB: cached interface with key %s was not a status", id) + } + + return status, true +} + +func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { + return s.conn. + NewSelect(). + Model(status). + Relation("Attachments"). + Relation("Tags"). + Relation("Mentions"). + Relation("Emojis"). + Relation("Account"). + Relation("InReplyToAccount"). + Relation("BoostOfAccount"). + Relation("CreatedWithApplication") +} + +func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status { + if status.InReplyToID != "" && status.InReplyTo == nil { + if inReplyTo, cached := s.statusCached(status.InReplyToID); cached { + status.InReplyTo = inReplyTo + } else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil { + status.InReplyTo = inReplyTo + } + } + + if status.BoostOfID != "" && status.BoostOf == nil { + if boostOf, cached := s.statusCached(status.BoostOfID); cached { + status.BoostOf = boostOf + } else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil { + status.BoostOf = boostOf + } + } + + return status +} + +func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery { + return s.conn. + NewSelect(). + Model(faves). + Relation("Account"). + Relation("TargetAccount"). + Relation("Status") +} + +func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) { + if status, cached := s.statusCached(id); cached { + return status, nil + } + + status := new(gtsmodel.Status) + + q := s.newStatusQ(status). + Where("status.id = ?", id) + + err := processErrorResponse(q.Scan(ctx)) + + if err != nil { + return nil, err + } + + if status != nil { + s.cacheStatus(id, status) + } + + return s.getAttachedStatuses(ctx, status), err +} + +func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) { + if status, cached := s.statusCached(uri); cached { + return status, nil + } + + status := >smodel.Status{} + + q := s.newStatusQ(status). + Where("LOWER(status.uri) = LOWER(?)", uri) + + err := processErrorResponse(q.Scan(ctx)) + + if err != nil { + return nil, err + } + + if status != nil { + s.cacheStatus(uri, status) + } + + return s.getAttachedStatuses(ctx, status), err +} + +func (s *statusDB) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) { + if status, cached := s.statusCached(uri); cached { + return status, nil + } + + status := >smodel.Status{} + + q := s.newStatusQ(status). + Where("LOWER(status.url) = LOWER(?)", uri) + + err := processErrorResponse(q.Scan(ctx)) + + if err != nil { + return nil, err + } + + if status != nil { + s.cacheStatus(uri, status) + } + + return s.getAttachedStatuses(ctx, status), err +} + +func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error { + transaction := func(ctx context.Context, tx bun.Tx) error { + // create links between this status and any emojis it uses + for _, i := range status.EmojiIDs { + if _, err := tx.NewInsert().Model(>smodel.StatusToEmoji{ + StatusID: status.ID, + EmojiID: i, + }).Exec(ctx); err != nil { + return err + } + } + + // create links between this status and any tags it uses + for _, i := range status.TagIDs { + if _, err := tx.NewInsert().Model(>smodel.StatusToTag{ + StatusID: status.ID, + TagID: i, + }).Exec(ctx); err != nil { + return err + } + } + + // change the status ID of the media attachments to the new status + for _, a := range status.Attachments { + a.StatusID = status.ID + a.UpdatedAt = time.Now() + if _, err := s.conn.NewUpdate().Model(a). + Where("id = ?", a.ID). + Exec(ctx); err != nil { + return err + } + } + + _, err := tx.NewInsert().Model(status).Exec(ctx) + return err + } + + return processErrorResponse(s.conn.RunInTx(ctx, nil, transaction)) +} + +func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { + parents := []*gtsmodel.Status{} + s.statusParent(ctx, status, &parents, onlyDirect) + + return parents, nil +} + +func (s *statusDB) statusParent(ctx context.Context, status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) { + if status.InReplyToID == "" { + return + } + + parentStatus, err := s.GetStatusByID(ctx, status.InReplyToID) + if err == nil { + *foundStatuses = append(*foundStatuses, parentStatus) + } + + if onlyDirect { + return + } + + s.statusParent(ctx, parentStatus, foundStatuses, false) +} + +func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) { + foundStatuses := &list.List{} + foundStatuses.PushFront(status) + s.statusChildren(ctx, status, foundStatuses, onlyDirect, minID) + + children := []*gtsmodel.Status{} + for e := foundStatuses.Front(); e != nil; e = e.Next() { + entry, ok := e.Value.(*gtsmodel.Status) + if !ok { + panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) + } + + // only append children, not the overall parent status + if entry.ID != status.ID { + children = append(children, entry) + } + } + + return children, nil +} + +func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) { + immediateChildren := []*gtsmodel.Status{} + + q := s.conn. + NewSelect(). + Model(&immediateChildren). + Where("in_reply_to_id = ?", status.ID) + if minID != "" { + q = q.Where("status.id > ?", minID) + } + + if err := q.Scan(ctx); err != nil { + return + } + + for _, child := range immediateChildren { + insertLoop: + for e := foundStatuses.Front(); e != nil; e = e.Next() { + entry, ok := e.Value.(*gtsmodel.Status) + if !ok { + panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) + } + + if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID { + foundStatuses.InsertAfter(child, e) + break insertLoop + } + } + + // only do one loop if we only want direct children + if onlyDirect { + return + } + s.statusChildren(ctx, child, foundStatuses, false, minID) + } +} + +func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { + return s.conn.NewSelect().Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count(ctx) +} + +func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { + return s.conn.NewSelect().Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count(ctx) +} + +func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { + return s.conn.NewSelect().Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count(ctx) +} + +func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { + q := s.conn. + NewSelect(). + Model(>smodel.StatusFave{}). + Where("status_id = ?", status.ID). + Where("account_id = ?", accountID) + + return exists(ctx, q) +} + +func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { + q := s.conn. + NewSelect(). + Model(>smodel.Status{}). + Where("boost_of_id = ?", status.ID). + Where("account_id = ?", accountID) + + return exists(ctx, q) +} + +func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { + q := s.conn. + NewSelect(). + Model(>smodel.StatusMute{}). + Where("status_id = ?", status.ID). + Where("account_id = ?", accountID) + + return exists(ctx, q) +} + +func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { + q := s.conn. + NewSelect(). + Model(>smodel.StatusBookmark{}). + Where("status_id = ?", status.ID). + Where("account_id = ?", accountID) + + return exists(ctx, q) +} + +func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) { + faves := []*gtsmodel.StatusFave{} + + q := s.newFaveQ(&faves). + Where("status_id = ?", status.ID) + + err := processErrorResponse(q.Scan(ctx)) + return faves, err +} + +func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) { + reblogs := []*gtsmodel.Status{} + + q := s.newStatusQ(&reblogs). + Where("boost_of_id = ?", status.ID) + + err := processErrorResponse(q.Scan(ctx)) + return reblogs, err +} diff --git a/internal/db/pg/status_test.go b/internal/db/bundb/status_test.go index 8a185757c..513000577 100644 --- a/internal/db/pg/status_test.go +++ b/internal/db/bundb/status_test.go @@ -16,9 +16,10 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg_test +package bundb_test import ( + "context" "fmt" "testing" "time" @@ -28,7 +29,7 @@ import ( ) type StatusTestSuite struct { - PGStandardTestSuite + BunDBStandardTestSuite } func (suite *StatusTestSuite) SetupSuite() { @@ -56,8 +57,9 @@ func (suite *StatusTestSuite) TearDownTest() { } func (suite *StatusTestSuite) TestGetStatusByID() { - status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_1_status_1"].ID) + status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_1_status_1"].ID) if err != nil { + fmt.Println(err.Error()) suite.FailNow(err.Error()) } suite.NotNil(status) @@ -70,7 +72,7 @@ func (suite *StatusTestSuite) TestGetStatusByID() { } func (suite *StatusTestSuite) TestGetStatusByURI() { - status, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI) + status, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI) if err != nil { suite.FailNow(err.Error()) } @@ -84,7 +86,7 @@ func (suite *StatusTestSuite) TestGetStatusByURI() { } func (suite *StatusTestSuite) TestGetStatusWithExtras() { - status, err := suite.db.GetStatusByID(suite.testStatuses["admin_account_status_1"].ID) + status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["admin_account_status_1"].ID) if err != nil { suite.FailNow(err.Error()) } @@ -97,7 +99,7 @@ func (suite *StatusTestSuite) TestGetStatusWithExtras() { } func (suite *StatusTestSuite) TestGetStatusWithMention() { - status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_2_status_5"].ID) + status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_2_status_5"].ID) if err != nil { suite.FailNow(err.Error()) } @@ -112,18 +114,18 @@ func (suite *StatusTestSuite) TestGetStatusWithMention() { func (suite *StatusTestSuite) TestGetStatusTwice() { before1 := time.Now() - _, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI) + _, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI) suite.NoError(err) after1 := time.Now() duration1 := after1.Sub(before1) - fmt.Println(duration1.Nanoseconds()) + fmt.Println(duration1.Milliseconds()) before2 := time.Now() - _, err = suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI) + _, err = suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI) suite.NoError(err) after2 := time.Now() duration2 := after2.Sub(before2) - fmt.Println(duration2.Nanoseconds()) + fmt.Println(duration2.Milliseconds()) // second retrieval should be several orders faster since it will be cached now suite.Less(duration2, duration1) diff --git a/internal/db/pg/timeline.go b/internal/db/bundb/timeline.go index fa8b07aab..b62ad4c50 100644 --- a/internal/db/pg/timeline.go +++ b/internal/db/bundb/timeline.go @@ -16,43 +16,35 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. */ -package pg +package bundb import ( "context" + "database/sql" "sort" - "github.com/go-pg/pg/v10" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" ) type timelineDB struct { config *config.Config - conn *pg.DB + conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } -func (t *timelineDB) GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { +func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { statuses := []*gtsmodel.Status{} - q := t.conn.Model(&statuses) + q := t.conn. + NewSelect(). + Model(&statuses) q = q.ColumnExpr("status.*"). // Find out who accountID follows. Join("LEFT JOIN follows AS f ON f.target_account_id = status.account_id"). - // Use a WhereGroup here to specify that we want EITHER statuses posted by accounts that accountID follows, - // OR statuses posted by accountID itself (since a user should be able to see their own statuses). - // - // This is equivalent to something like WHERE ... AND (... OR ...) - // See: https://pg.uptrace.dev/queries/#select - WhereGroup(func(q *pg.Query) (*pg.Query, error) { - q = q.WhereOr("f.account_id = ?", accountID). - WhereOr("status.account_id = ?", accountID) - return q, nil - }). // Sort by highest ID (newest) to lowest ID (oldest) Order("status.id DESC") @@ -81,29 +73,32 @@ func (t *timelineDB) GetHomeTimeline(accountID string, maxID string, sinceID str q = q.Limit(limit) } - err := q.Select() - if err != nil { - if err == pg.ErrNoRows { - return nil, db.ErrNoEntries - } - return nil, err + // Use a WhereGroup here to specify that we want EITHER statuses posted by accounts that accountID follows, + // OR statuses posted by accountID itself (since a user should be able to see their own statuses). + // + // This is equivalent to something like WHERE ... AND (... OR ...) + // See: https://bun.uptrace.dev/guide/queries.html#select + whereGroup := func(*bun.SelectQuery) *bun.SelectQuery { + return q. + WhereOr("f.account_id = ?", accountID). + WhereOr("status.account_id = ?", accountID) } - if len(statuses) == 0 { - return nil, db.ErrNoEntries - } + q = q.WhereGroup(" AND ", whereGroup) - return statuses, nil + return statuses, processErrorResponse(q.Scan(ctx)) } -func (t *timelineDB) GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { +func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { statuses := []*gtsmodel.Status{} - q := t.conn.Model(&statuses). + q := t.conn. + NewSelect(). + Model(&statuses). Where("visibility = ?", gtsmodel.VisibilityPublic). - Where("? IS NULL", pg.Ident("in_reply_to_id")). - Where("? IS NULL", pg.Ident("in_reply_to_uri")). - Where("? IS NULL", pg.Ident("boost_of_id")). + Where("? IS NULL", bun.Ident("in_reply_to_id")). + Where("? IS NULL", bun.Ident("in_reply_to_uri")). + Where("? IS NULL", bun.Ident("boost_of_id")). Order("status.id DESC") if maxID != "" { @@ -126,28 +121,18 @@ func (t *timelineDB) GetPublicTimeline(accountID string, maxID string, sinceID s q = q.Limit(limit) } - err := q.Select() - if err != nil { - if err == pg.ErrNoRows { - return nil, db.ErrNoEntries - } - return nil, err - } - - if len(statuses) == 0 { - return nil, db.ErrNoEntries - } - - return statuses, nil + return statuses, processErrorResponse(q.Scan(ctx)) } // TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20! // It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds. -func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) { +func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) { faves := []*gtsmodel.StatusFave{} - fq := t.conn.Model(&faves). + fq := t.conn. + NewSelect(). + Model(&faves). Where("account_id = ?", accountID). Order("id DESC") @@ -163,9 +148,9 @@ func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID stri fq = fq.Limit(limit) } - err := fq.Select() + err := fq.Scan(ctx) if err != nil { - if err == pg.ErrNoRows { + if err == sql.ErrNoRows { return nil, "", "", db.ErrNoEntries } return nil, "", "", err @@ -185,9 +170,13 @@ func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID stri } statuses := []*gtsmodel.Status{} - err = t.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select() + err = t.conn. + NewSelect(). + Model(&statuses). + Where("id IN (?)", bun.In(in)). + Scan(ctx) if err != nil { - if err == pg.ErrNoRows { + if err == sql.ErrNoRows { return nil, "", "", db.ErrNoEntries } return nil, "", "", err diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go new file mode 100644 index 000000000..115d18de2 --- /dev/null +++ b/internal/db/bundb/util.go @@ -0,0 +1,78 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + "strings" + + "database/sql" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/uptrace/bun" +) + +// processErrorResponse parses the given error and returns an appropriate DBError. +func processErrorResponse(err error) db.Error { + switch err { + case nil: + return nil + case sql.ErrNoRows: + return db.ErrNoEntries + default: + if strings.Contains(err.Error(), "duplicate key value violates unique constraint") { + return db.ErrAlreadyExists + } + return err + } +} + +func exists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) { + count, err := q.Count(ctx) + + exists := count != 0 + + err = processErrorResponse(err) + + if err != nil { + if err == db.ErrNoEntries { + return false, nil + } + return false, err + } + + return exists, nil +} + +func notExists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) { + count, err := q.Count(ctx) + + notExists := count == 0 + + err = processErrorResponse(err) + + if err != nil { + if err == db.ErrNoEntries { + return true, nil + } + return false, err + } + + return notExists, nil +} diff --git a/internal/db/db.go b/internal/db/db.go index d6ac883e4..ec94fcfe7 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -19,6 +19,8 @@ package db import ( + "context" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) @@ -38,6 +40,7 @@ type DB interface { Mention Notification Relationship + Session Status Timeline @@ -52,7 +55,7 @@ type DB interface { // // Note: this func doesn't/shouldn't do any manipulation of the accounts in the DB, it's just for checking // if they exist in the db and conveniently returning them if they do. - MentionStringsToMentions(targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) + MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) // TagStringsToTags takes a slice of deduplicated, lowercase tags in the form "somehashtag", which have been // used in a status. It takes the id of the account that wrote the status, and the id of the status itself, and then @@ -61,7 +64,7 @@ type DB interface { // // Note: this func doesn't/shouldn't do any manipulation of the tags in the DB, it's just for checking // if they exist in the db already, and conveniently returning them, or creating new tag structs. - TagStringsToTags(tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) + TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) // EmojiStringsToEmojis takes a slice of deduplicated, lowercase emojis in the form ":emojiname:", which have been // used in a status. It takes the id of the account that wrote the status, and the id of the status itself, and then @@ -69,5 +72,5 @@ type DB interface { // // Note: this func doesn't/shouldn't do any manipulation of the emoji in the DB, it's just for checking // if they exist in the db and conveniently returning them if they do. - EmojiStringsToEmojis(emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) + EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) } diff --git a/internal/db/domain.go b/internal/db/domain.go index a6583c80c..df50a6770 100644 --- a/internal/db/domain.go +++ b/internal/db/domain.go @@ -18,19 +18,22 @@ package db -import "net/url" +import ( + "context" + "net/url" +) // Domain contains DB functions related to domains and domain blocks. type Domain interface { // IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`). - IsDomainBlocked(domain string) (bool, Error) + IsDomainBlocked(ctx context.Context, domain string) (bool, Error) // AreDomainsBlocked checks if an instance-level domain block exists for any of the given domains strings, and returns true if even one is found. - AreDomainsBlocked(domains []string) (bool, Error) + AreDomainsBlocked(ctx context.Context, domains []string) (bool, Error) // IsURIBlocked checks if an instance-level domain block exists for the `host` in the given URI (eg., `https://example.org/users/whatever`). - IsURIBlocked(uri *url.URL) (bool, Error) + IsURIBlocked(ctx context.Context, uri *url.URL) (bool, Error) // AreURIsBlocked checks if an instance-level domain block exists for any `host` in the given URI slice, and returns true if even one is found. - AreURIsBlocked(uris []*url.URL) (bool, Error) + AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, Error) } diff --git a/internal/db/instance.go b/internal/db/instance.go index 1f7c83e4f..dcd978a81 100644 --- a/internal/db/instance.go +++ b/internal/db/instance.go @@ -18,19 +18,23 @@ package db -import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) // Instance contains functions for instance-level actions (counting instance users etc.). type Instance interface { // CountInstanceUsers returns the number of known accounts registered with the given domain. - CountInstanceUsers(domain string) (int, Error) + CountInstanceUsers(ctx context.Context, domain string) (int, Error) // CountInstanceStatuses returns the number of known statuses posted from the given domain. - CountInstanceStatuses(domain string) (int, Error) + CountInstanceStatuses(ctx context.Context, domain string) (int, Error) // CountInstanceDomains returns the number of known instances known that the given domain federates with. - CountInstanceDomains(domain string) (int, Error) + CountInstanceDomains(ctx context.Context, domain string) (int, Error) // GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID. - GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, Error) + GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, Error) } diff --git a/internal/db/media.go b/internal/db/media.go index db4db3411..b779dd276 100644 --- a/internal/db/media.go +++ b/internal/db/media.go @@ -18,10 +18,14 @@ package db -import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) // Media contains functions related to creating/getting/removing media attachments. type Media interface { // GetAttachmentByID gets a single attachment by its ID - GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, Error) + GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, Error) } diff --git a/internal/db/mention.go b/internal/db/mention.go index cb1c56dc1..b9b45546a 100644 --- a/internal/db/mention.go +++ b/internal/db/mention.go @@ -18,13 +18,17 @@ package db -import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) // Mention contains functions for getting/creating mentions in the database. type Mention interface { // GetMention gets a single mention by ID - GetMention(id string) (*gtsmodel.Mention, Error) + GetMention(ctx context.Context, id string) (*gtsmodel.Mention, Error) // GetMentions gets multiple mentions. - GetMentions(ids []string) ([]*gtsmodel.Mention, Error) + GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, Error) } diff --git a/internal/db/notification.go b/internal/db/notification.go index 326f0f149..09c17f031 100644 --- a/internal/db/notification.go +++ b/internal/db/notification.go @@ -18,14 +18,18 @@ package db -import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) // Notification contains functions for creating and getting notifications. type Notification interface { // GetNotifications returns a slice of notifications that pertain to the given accountID. // // Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest). - GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error) + GetNotifications(ctx context.Context, accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error) // GetNotification returns one notification according to its id. - GetNotification(id string) (*gtsmodel.Notification, Error) + GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, Error) } diff --git a/internal/db/pg/basic.go b/internal/db/pg/basic.go deleted file mode 100644 index 6e76b4450..000000000 --- a/internal/db/pg/basic.go +++ /dev/null @@ -1,205 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg - -import ( - "context" - "errors" - "fmt" - "strings" - - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" - "github.com/sirupsen/logrus" - "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" -) - -type basicDB struct { - config *config.Config - conn *pg.DB - log *logrus.Logger - cancel context.CancelFunc -} - -func (b *basicDB) Put(i interface{}) db.Error { - _, err := b.conn.Model(i).Insert(i) - if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") { - return db.ErrAlreadyExists - } - return err -} - -func (b *basicDB) GetByID(id string, i interface{}) db.Error { - if err := b.conn.Model(i).Where("id = ?", id).Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries - } - return err - - } - return nil -} - -func (b *basicDB) GetWhere(where []db.Where, i interface{}) db.Error { - if len(where) == 0 { - return errors.New("no queries provided") - } - - q := b.conn.Model(i) - for _, w := range where { - - if w.Value == nil { - q = q.Where("? IS NULL", pg.Ident(w.Key)) - } else { - if w.CaseInsensitive { - q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value) - } else { - q = q.Where("? = ?", pg.Safe(w.Key), w.Value) - } - } - } - - if err := q.Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries - } - return err - } - return nil -} - -func (b *basicDB) GetAll(i interface{}) db.Error { - if err := b.conn.Model(i).Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries - } - return err - } - return nil -} - -func (b *basicDB) DeleteByID(id string, i interface{}) db.Error { - if _, err := b.conn.Model(i).Where("id = ?", id).Delete(); err != nil { - // if there are no rows *anyway* then that's fine - // just return err if there's an actual error - if err != pg.ErrNoRows { - return err - } - } - return nil -} - -func (b *basicDB) DeleteWhere(where []db.Where, i interface{}) db.Error { - if len(where) == 0 { - return errors.New("no queries provided") - } - - q := b.conn.Model(i) - for _, w := range where { - q = q.Where("? = ?", pg.Safe(w.Key), w.Value) - } - - if _, err := q.Delete(); err != nil { - // if there are no rows *anyway* then that's fine - // just return err if there's an actual error - if err != pg.ErrNoRows { - return err - } - } - return nil -} - -func (b *basicDB) Upsert(i interface{}, conflictColumn string) db.Error { - if _, err := b.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries - } - return err - } - return nil -} - -func (b *basicDB) UpdateByID(id string, i interface{}) db.Error { - if _, err := b.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries - } - return err - } - return nil -} - -func (b *basicDB) UpdateOneByID(id string, key string, value interface{}, i interface{}) db.Error { - _, err := b.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update() - return err -} - -func (b *basicDB) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) db.Error { - q := b.conn.Model(i) - - for _, w := range where { - if w.Value == nil { - q = q.Where("? IS NULL", pg.Ident(w.Key)) - } else { - if w.CaseInsensitive { - q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value) - } else { - q = q.Where("? = ?", pg.Safe(w.Key), w.Value) - } - } - } - - q = q.Set("? = ?", pg.Safe(key), value) - - _, err := q.Update() - - return err -} - -func (b *basicDB) CreateTable(i interface{}) db.Error { - return b.conn.Model(i).CreateTable(&orm.CreateTableOptions{ - IfNotExists: true, - }) -} - -func (b *basicDB) DropTable(i interface{}) db.Error { - return b.conn.Model(i).DropTable(&orm.DropTableOptions{ - IfExists: true, - }) -} - -func (b *basicDB) RegisterTable(i interface{}) db.Error { - orm.RegisterTable(i) - return nil -} - -func (b *basicDB) IsHealthy(ctx context.Context) db.Error { - return b.conn.Ping(ctx) -} - -func (b *basicDB) Stop(ctx context.Context) db.Error { - b.log.Info("closing db connection") - if err := b.conn.Close(); err != nil { - // only cancel if there's a problem closing the db - b.cancel() - return err - } - return nil -} diff --git a/internal/db/pg/status.go b/internal/db/pg/status.go deleted file mode 100644 index 99790428e..000000000 --- a/internal/db/pg/status.go +++ /dev/null @@ -1,318 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg - -import ( - "container/list" - "context" - "errors" - "time" - - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" - "github.com/sirupsen/logrus" - "github.com/superseriousbusiness/gotosocial/internal/cache" - "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -type statusDB struct { - config *config.Config - conn *pg.DB - log *logrus.Logger - cancel context.CancelFunc - cache cache.Cache -} - -func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) { - if s.cache == nil { - s.cache = cache.New() - } - - if err := s.cache.Store(id, status); err != nil { - s.log.Panicf("statusDB: error storing in cache: %s", err) - } -} - -func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) { - if s.cache == nil { - s.cache = cache.New() - return nil, false - } - - sI, err := s.cache.Fetch(id) - if err != nil || sI == nil { - return nil, false - } - - status, ok := sI.(*gtsmodel.Status) - if !ok { - s.log.Panicf("statusDB: cached interface with key %s was not a status", id) - } - - return status, true -} - -func (s *statusDB) newStatusQ(status interface{}) *orm.Query { - return s.conn.Model(status). - Relation("Attachments"). - Relation("Tags"). - Relation("Mentions"). - Relation("Emojis"). - Relation("Account"). - Relation("InReplyTo"). - Relation("InReplyToAccount"). - Relation("BoostOf"). - Relation("BoostOfAccount"). - Relation("CreatedWithApplication") -} - -func (s *statusDB) newFaveQ(faves interface{}) *orm.Query { - return s.conn.Model(faves). - Relation("Account"). - Relation("TargetAccount"). - Relation("Status") -} - -func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.Error) { - if status, cached := s.statusCached(id); cached { - return status, nil - } - - status := >smodel.Status{} - - q := s.newStatusQ(status). - Where("status.id = ?", id) - - err := processErrorResponse(q.Select()) - - if err == nil && status != nil { - s.cacheStatus(id, status) - } - - return status, err -} - -func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.Error) { - if status, cached := s.statusCached(uri); cached { - return status, nil - } - - status := >smodel.Status{} - - q := s.newStatusQ(status). - Where("LOWER(status.uri) = LOWER(?)", uri) - - err := processErrorResponse(q.Select()) - - if err == nil && status != nil { - s.cacheStatus(uri, status) - } - - return status, err -} - -func (s *statusDB) GetStatusByURL(uri string) (*gtsmodel.Status, db.Error) { - if status, cached := s.statusCached(uri); cached { - return status, nil - } - - status := >smodel.Status{} - - q := s.newStatusQ(status). - Where("LOWER(status.url) = LOWER(?)", uri) - - err := processErrorResponse(q.Select()) - - if err == nil && status != nil { - s.cacheStatus(uri, status) - } - - return status, err -} - -func (s *statusDB) PutStatus(status *gtsmodel.Status) db.Error { - transaction := func(tx *pg.Tx) error { - // create links between this status and any emojis it uses - for _, i := range status.EmojiIDs { - if _, err := tx.Model(>smodel.StatusToEmoji{ - StatusID: status.ID, - EmojiID: i, - }).Insert(); err != nil { - return err - } - } - - // create links between this status and any tags it uses - for _, i := range status.TagIDs { - if _, err := tx.Model(>smodel.StatusToTag{ - StatusID: status.ID, - TagID: i, - }).Insert(); err != nil { - return err - } - } - - // change the status ID of the media attachments to the new status - for _, a := range status.Attachments { - a.StatusID = status.ID - a.UpdatedAt = time.Now() - if _, err := s.conn.Model(a). - Where("id = ?", a.ID). - Update(); err != nil { - return err - } - } - - _, err := tx.Model(status).Insert() - return err - } - - return processErrorResponse(s.conn.RunInTransaction(context.Background(), transaction)) -} - -func (s *statusDB) GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { - parents := []*gtsmodel.Status{} - s.statusParent(status, &parents, onlyDirect) - - return parents, nil -} - -func (s *statusDB) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) { - if status.InReplyToID == "" { - return - } - - parentStatus, err := s.GetStatusByID(status.InReplyToID) - if err == nil { - *foundStatuses = append(*foundStatuses, parentStatus) - } - - if onlyDirect { - return - } - - s.statusParent(parentStatus, foundStatuses, false) -} - -func (s *statusDB) GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) { - foundStatuses := &list.List{} - foundStatuses.PushFront(status) - s.statusChildren(status, foundStatuses, onlyDirect, minID) - - children := []*gtsmodel.Status{} - for e := foundStatuses.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*gtsmodel.Status) - if !ok { - panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) - } - - // only append children, not the overall parent status - if entry.ID != status.ID { - children = append(children, entry) - } - } - - return children, nil -} - -func (s *statusDB) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) { - immediateChildren := []*gtsmodel.Status{} - - q := s.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID) - if minID != "" { - q = q.Where("status.id > ?", minID) - } - - if err := q.Select(); err != nil { - return - } - - for _, child := range immediateChildren { - insertLoop: - for e := foundStatuses.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*gtsmodel.Status) - if !ok { - panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) - } - - if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID { - foundStatuses.InsertAfter(child, e) - break insertLoop - } - } - - // only do one loop if we only want direct children - if onlyDirect { - return - } - s.statusChildren(child, foundStatuses, false, minID) - } -} - -func (s *statusDB) CountStatusReplies(status *gtsmodel.Status) (int, db.Error) { - return s.conn.Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count() -} - -func (s *statusDB) CountStatusReblogs(status *gtsmodel.Status) (int, db.Error) { - return s.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count() -} - -func (s *statusDB) CountStatusFaves(status *gtsmodel.Status) (int, db.Error) { - return s.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count() -} - -func (s *statusDB) IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) { - return s.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists() -} - -func (s *statusDB) IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) { - return s.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists() -} - -func (s *statusDB) IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) { - return s.conn.Model(>smodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists() -} - -func (s *statusDB) IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) { - return s.conn.Model(>smodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists() -} - -func (s *statusDB) GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) { - faves := []*gtsmodel.StatusFave{} - - q := s.newFaveQ(&faves). - Where("status_id = ?", status.ID) - - err := processErrorResponse(q.Select()) - - return faves, err -} - -func (s *statusDB) GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) { - reblogs := []*gtsmodel.Status{} - - q := s.newStatusQ(&reblogs). - Where("boost_of_id = ?", status.ID) - - err := processErrorResponse(q.Select()) - - return reblogs, err -} diff --git a/internal/db/pg/util.go b/internal/db/pg/util.go deleted file mode 100644 index 17c09b720..000000000 --- a/internal/db/pg/util.go +++ /dev/null @@ -1,25 +0,0 @@ -package pg - -import ( - "strings" - - "github.com/go-pg/pg/v10" - "github.com/superseriousbusiness/gotosocial/internal/db" -) - -// processErrorResponse parses the given error and returns an appropriate DBError. -func processErrorResponse(err error) db.Error { - switch err { - case nil: - return nil - case pg.ErrNoRows: - return db.ErrNoEntries - case pg.ErrMultiRows: - return db.ErrMultipleEntries - default: - if strings.Contains(err.Error(), "duplicate key value violates unique constraint") { - return db.ErrAlreadyExists - } - return err - } -} diff --git a/internal/db/relationship.go b/internal/db/relationship.go index 85f64d72b..804526425 100644 --- a/internal/db/relationship.go +++ b/internal/db/relationship.go @@ -18,54 +18,58 @@ package db -import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) // Relationship contains functions for getting or modifying the relationship between two accounts. type Relationship interface { // IsBlocked checks whether account 1 has a block in place against block2. // If eitherDirection is true, then the function returns true if account1 blocks account2, OR if account2 blocks account1. - IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, Error) + IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, Error) // GetBlock returns the block from account1 targeting account2, if it exists, or an error if it doesn't. // // Because this is slower than Blocked, only use it if you need the actual Block struct for some reason, // not if you're just checking for the existence of a block. - GetBlock(account1 string, account2 string) (*gtsmodel.Block, Error) + GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, Error) // GetRelationship retrieves the relationship of the targetAccount to the requestingAccount. - GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error) + GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error) // IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out. - IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error) + IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error) // IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out. - IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error) + IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error) // IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out. - IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error) + IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error) // AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table. // In other words, it should create the follow, and delete the existing follow request. // // It will return the newly created follow for further processing. - AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error) + AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error) // GetAccountFollowRequests returns all follow requests targeting the given account. - GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, Error) + GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, Error) // GetAccountFollows returns a slice of follows owned by the given accountID. - GetAccountFollows(accountID string) ([]*gtsmodel.Follow, Error) + GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, Error) // CountAccountFollows returns the amount of accounts that the given accountID is following. // // If localOnly is set to true, then only follows from *this instance* will be returned. - CountAccountFollows(accountID string, localOnly bool) (int, Error) + CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, Error) // GetAccountFollowedBy fetches follows that target given accountID. // // If localOnly is set to true, then only follows from *this instance* will be returned. - GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, Error) + GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, Error) // CountAccountFollowedBy returns the amounts that the given ID is followed by. - CountAccountFollowedBy(accountID string, localOnly bool) (int, Error) + CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, Error) } diff --git a/internal/db/session.go b/internal/db/session.go new file mode 100644 index 000000000..ae13dccce --- /dev/null +++ b/internal/db/session.go @@ -0,0 +1,31 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package db + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// Session handles getting/creation of router sessions. +type Session interface { + GetSession(ctx context.Context) (*gtsmodel.RouterSession, Error) + CreateSession(ctx context.Context) (*gtsmodel.RouterSession, Error) +} diff --git a/internal/db/status.go b/internal/db/status.go index 9d206c198..7430433c4 100644 --- a/internal/db/status.go +++ b/internal/db/status.go @@ -18,58 +18,62 @@ package db -import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) // Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses. type Status interface { // GetStatusByID returns one status from the database, with all rel fields populated (if possible). - GetStatusByID(id string) (*gtsmodel.Status, Error) + GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, Error) // GetStatusByURI returns one status from the database, with all rel fields populated (if possible). - GetStatusByURI(uri string) (*gtsmodel.Status, Error) + GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, Error) // GetStatusByURL returns one status from the database, with all rel fields populated (if possible). - GetStatusByURL(uri string) (*gtsmodel.Status, Error) + GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error) // PutStatus stores one status in the database. - PutStatus(status *gtsmodel.Status) Error + PutStatus(ctx context.Context, status *gtsmodel.Status) Error // CountStatusReplies returns the amount of replies recorded for a status, or an error if something goes wrong - CountStatusReplies(status *gtsmodel.Status) (int, Error) + CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, Error) // CountStatusReblogs returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong - CountStatusReblogs(status *gtsmodel.Status) (int, Error) + CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, Error) // CountStatusFaves returns the amount of faves/likes recorded for a status, or an error if something goes wrong - CountStatusFaves(status *gtsmodel.Status) (int, Error) + CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, Error) // GetStatusParents gets the parent statuses of a given status. // // If onlyDirect is true, only the immediate parent will be returned. - GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error) + GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error) // GetStatusChildren gets the child statuses of a given status. // // If onlyDirect is true, only the immediate children will be returned. - GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error) + GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error) // IsStatusFavedBy checks if a given status has been faved by a given account ID - IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, Error) + IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error) // IsStatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID - IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, Error) + IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error) // IsStatusMutedBy checks if a given status has been muted by a given account ID - IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, Error) + IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error) // IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID - IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, Error) + IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error) // GetStatusFaves returns a slice of faves/likes of the given status. // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user. - GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, Error) + GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, Error) // GetStatusReblogs returns a slice of statuses that are a boost/reblog of the given status. // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user. - GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, Error) + GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, Error) } diff --git a/internal/db/timeline.go b/internal/db/timeline.go index 74aa5c781..83fb3a959 100644 --- a/internal/db/timeline.go +++ b/internal/db/timeline.go @@ -18,20 +18,24 @@ package db -import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) // Timeline contains functionality for retrieving home/public/faved etc timelines for an account. type Timeline interface { // GetHomeTimeline returns a slice of statuses from accounts that are followed by the given account id. // // Statuses should be returned in descending order of when they were created (newest first). - GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error) + GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error) // GetPublicTimeline fetches the account's PUBLIC timeline -- ie., posts and replies that are public. // It will use the given filters and try to return as many statuses as possible up to the limit. // // Statuses should be returned in descending order of when they were created (newest first). - GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error) + GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error) // GetFavedTimeline fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved. // It will use the given filters and try to return as many statuses as possible up to the limit. @@ -40,5 +44,5 @@ type Timeline interface { // In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created. // // Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers. - GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error) + GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error) } |