summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
authorLibravatar tobi <31960611+tsmethurst@users.noreply.github.com>2021-08-25 15:34:33 +0200
committerLibravatar GitHub <noreply@github.com>2021-08-25 15:34:33 +0200
commit2dc9fc1626507bb54417fc4a1920b847cafb27a2 (patch)
tree4ddeac479b923db38090aac8bd9209f3646851c1 /internal/db
parentManually approves followers (#146) (diff)
downloadgotosocial-2dc9fc1626507bb54417fc4a1920b847cafb27a2.tar.xz
Pg to bun (#148)
* start moving to bun * changing more stuff * more * and yet more * tests passing * seems stable now * more big changes * small fix * little fixes
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/account.go26
-rw-r--r--internal/db/admin.go11
-rw-r--r--internal/db/basic.go31
-rw-r--r--internal/db/bundb/account.go (renamed from internal/db/pg/account.go)155
-rw-r--r--internal/db/bundb/account_test.go (renamed from internal/db/pg/account_test.go)22
-rw-r--r--internal/db/bundb/admin.go (renamed from internal/db/pg/admin.go)149
-rw-r--r--internal/db/bundb/basic.go179
-rw-r--r--internal/db/bundb/basic_test.go68
-rw-r--r--internal/db/bundb/bundb.go (renamed from internal/db/pg/pg.go)150
-rw-r--r--internal/db/bundb/bundb_test.go (renamed from internal/db/pg/pg_test.go)4
-rw-r--r--internal/db/bundb/domain.go (renamed from internal/db/pg/domain.go)30
-rw-r--r--internal/db/bundb/instance.go (renamed from internal/db/pg/instance.go)66
-rw-r--r--internal/db/bundb/media.go (renamed from internal/db/pg/media.go)18
-rw-r--r--internal/db/bundb/mention.go (renamed from internal/db/pg/mention.go)22
-rw-r--r--internal/db/bundb/notification.go (renamed from internal/db/pg/notification.go)25
-rw-r--r--internal/db/bundb/relationship.go (renamed from internal/db/pg/relationship.go)172
-rw-r--r--internal/db/bundb/session.go85
-rw-r--r--internal/db/bundb/status.go375
-rw-r--r--internal/db/bundb/status_test.go (renamed from internal/db/pg/status_test.go)22
-rw-r--r--internal/db/bundb/timeline.go (renamed from internal/db/pg/timeline.go)89
-rw-r--r--internal/db/bundb/util.go78
-rw-r--r--internal/db/db.go9
-rw-r--r--internal/db/domain.go13
-rw-r--r--internal/db/instance.go14
-rw-r--r--internal/db/media.go8
-rw-r--r--internal/db/mention.go10
-rw-r--r--internal/db/notification.go10
-rw-r--r--internal/db/pg/basic.go205
-rw-r--r--internal/db/pg/status.go318
-rw-r--r--internal/db/pg/util.go25
-rw-r--r--internal/db/relationship.go30
-rw-r--r--internal/db/session.go31
-rw-r--r--internal/db/status.go36
-rw-r--r--internal/db/timeline.go12
34 files changed, 1461 insertions, 1037 deletions
diff --git a/internal/db/account.go b/internal/db/account.go
index 0e1575f9b..058a89859 100644
--- a/internal/db/account.go
+++ b/internal/db/account.go
@@ -19,6 +19,7 @@
package db
import (
+ "context"
"time"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -27,40 +28,43 @@ import (
// Account contains functions related to account getting/setting/creation.
type Account interface {
// GetAccountByID returns one account with the given ID, or an error if something goes wrong.
- GetAccountByID(id string) (*gtsmodel.Account, Error)
+ GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, Error)
// GetAccountByURI returns one account with the given URI, or an error if something goes wrong.
- GetAccountByURI(uri string) (*gtsmodel.Account, Error)
+ GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
// GetAccountByURL returns one account with the given URL, or an error if something goes wrong.
- GetAccountByURL(uri string) (*gtsmodel.Account, Error)
+ GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error)
+
+ // UpdateAccount updates one account by ID.
+ UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)
// GetLocalAccountByUsername returns an account on this instance by its username.
- GetLocalAccountByUsername(username string) (*gtsmodel.Account, Error)
+ GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, Error)
// GetAccountFaves fetches faves/likes created by the target accountID.
- GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, Error)
+ GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, Error)
// GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID.
- CountAccountStatuses(accountID string) (int, Error)
+ CountAccountStatuses(ctx context.Context, accountID string) (int, Error)
// GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided
// then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
// be very memory intensive so you probably shouldn't do this!
// In case of no entries, a 'no entries' error will be returned
- GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, Error)
+ GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, Error)
- GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error)
+ GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error)
// GetAccountLastPosted simply gets the timestamp of the most recent post by the account.
//
// The returned time will be zero if account has never posted anything.
- GetAccountLastPosted(accountID string) (time.Time, Error)
+ GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, Error)
// SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment.
- SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error
+ SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error
// GetInstanceAccount returns the instance account for the given domain.
// If domain is empty, this instance account will be returned.
- GetInstanceAccount(domain string) (*gtsmodel.Account, Error)
+ GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, Error)
}
diff --git a/internal/db/admin.go b/internal/db/admin.go
index aa2b22f47..24d628e84 100644
--- a/internal/db/admin.go
+++ b/internal/db/admin.go
@@ -19,6 +19,7 @@
package db
import (
+ "context"
"net"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -28,26 +29,26 @@ import (
type Admin interface {
// IsUsernameAvailable checks whether a given username is available on our domain.
// Returns an error if the username is already taken, or something went wrong in the db.
- IsUsernameAvailable(username string) Error
+ IsUsernameAvailable(ctx context.Context, username string) (bool, Error)
// IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain.
// Return an error if:
// A) the email is already associated with an account
// B) we block signups from this email domain
// C) something went wrong in the db
- IsEmailAvailable(email string) Error
+ IsEmailAvailable(ctx context.Context, email string) (bool, Error)
// NewSignup creates a new user in the database with the given parameters.
// By the time this function is called, it should be assumed that all the parameters have passed validation!
- NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, Error)
+ NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, Error)
// CreateInstanceAccount creates an account in the database with the same username as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'.
// This is needed for things like serving files that belong to the instance and not an individual user/account.
- CreateInstanceAccount() Error
+ CreateInstanceAccount(ctx context.Context) Error
// CreateInstanceInstance creates an instance in the database with the same domain as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'.
// This is needed for things like serving instance information through /api/v1/instance
- CreateInstanceInstance() Error
+ CreateInstanceInstance(ctx context.Context) Error
}
diff --git a/internal/db/basic.go b/internal/db/basic.go
index 729920bba..cf65ddc09 100644
--- a/internal/db/basic.go
+++ b/internal/db/basic.go
@@ -24,15 +24,11 @@ import "context"
type Basic interface {
// CreateTable creates a table for the given interface.
// For implementations that don't use tables, this can just return nil.
- CreateTable(i interface{}) Error
+ CreateTable(ctx context.Context, i interface{}) Error
// DropTable drops the table for the given interface.
// For implementations that don't use tables, this can just return nil.
- DropTable(i interface{}) Error
-
- // RegisterTable registers a table for use in many2many relations.
- // For implementations that don't use tables, or many2many relations, this can just return nil.
- RegisterTable(i interface{}) Error
+ DropTable(ctx context.Context, i interface{}) Error
// Stop should stop and close the database connection cleanly, returning an error if this is not possible.
// If the database implementation doesn't need to be stopped, this can just return nil.
@@ -45,43 +41,38 @@ type Basic interface {
// for other implementations (for example, in-memory) it might just be the key of a map.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
- GetByID(id string, i interface{}) Error
+ GetByID(ctx context.Context, id string, i interface{}) Error
// GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the
// name of the key to select from.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
- GetWhere(where []Where, i interface{}) Error
+ GetWhere(ctx context.Context, where []Where, i interface{}) Error
// GetAll will try to get all entries of type i.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
- GetAll(i interface{}) Error
+ GetAll(ctx context.Context, i interface{}) Error
// Put simply stores i. It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- Put(i interface{}) Error
-
- // Upsert stores or updates i based on the given conflict column, as in https://www.postgresqltutorial.com/postgresql-upsert/
- // It is up to the implementation to figure out how to store it, and using what key.
- // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- Upsert(i interface{}, conflictColumn string) Error
+ Put(ctx context.Context, i interface{}) Error
// UpdateByID updates i with id id.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- UpdateByID(id string, i interface{}) Error
+ UpdateByID(ctx context.Context, id string, i interface{}) Error
// UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value.
- UpdateOneByID(id string, key string, value interface{}, i interface{}) Error
+ UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) Error
// UpdateWhere updates column key of interface i with the given value, where the given parameters apply.
- UpdateWhere(where []Where, key string, value interface{}, i interface{}) Error
+ UpdateWhere(ctx context.Context, where []Where, key string, value interface{}, i interface{}) Error
// DeleteByID removes i with id id.
// If i didn't exist anyway, then no error should be returned.
- DeleteByID(id string, i interface{}) Error
+ DeleteByID(ctx context.Context, id string, i interface{}) Error
// DeleteWhere deletes i where key = value
// If i didn't exist anyway, then no error should be returned.
- DeleteWhere(where []Where, i interface{}) Error
+ DeleteWhere(ctx context.Context, where []Where, i interface{}) Error
}
diff --git a/internal/db/pg/account.go b/internal/db/bundb/account.go
index 3889c6601..7ebb79a15 100644
--- a/internal/db/pg/account.go
+++ b/internal/db/bundb/account.go
@@ -16,70 +16,90 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg
+package bundb
import (
"context"
"errors"
"fmt"
+ "strings"
"time"
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type accountDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
-func (a *accountDB) newAccountQ(account *gtsmodel.Account) *orm.Query {
- return a.conn.Model(account).
+func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
+ return a.conn.
+ NewSelect().
+ Model(account).
Relation("AvatarMediaAttachment").
Relation("HeaderMediaAttachment")
}
-func (a *accountDB) GetAccountByID(id string) (*gtsmodel.Account, db.Error) {
- account := &gtsmodel.Account{}
+func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
+ account := new(gtsmodel.Account)
q := a.newAccountQ(account).
Where("account.id = ?", id)
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return account, err
}
-func (a *accountDB) GetAccountByURI(uri string) (*gtsmodel.Account, db.Error) {
- account := &gtsmodel.Account{}
+func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+ account := new(gtsmodel.Account)
q := a.newAccountQ(account).
Where("account.uri = ?", uri)
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return account, err
}
-func (a *accountDB) GetAccountByURL(uri string) (*gtsmodel.Account, db.Error) {
- account := &gtsmodel.Account{}
+func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+ account := new(gtsmodel.Account)
q := a.newAccountQ(account).
Where("account.url = ?", uri)
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return account, err
}
-func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.Error) {
- account := &gtsmodel.Account{}
+func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
+ if strings.TrimSpace(account.ID) == "" {
+ return nil, errors.New("account had no ID")
+ }
+
+ account.UpdatedAt = time.Now()
+
+ q := a.conn.
+ NewUpdate().
+ Model(account).
+ WherePK()
+
+ _, err := q.Exec(ctx)
+
+ err = processErrorResponse(err)
+
+ return account, err
+}
+
+func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
+ account := new(gtsmodel.Account)
q := a.newAccountQ(account)
@@ -90,29 +110,31 @@ func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.Err
} else {
q = q.
Where("account.username = ?", domain).
- Where("? IS NULL", pg.Ident("domain"))
+ Where("? IS NULL", bun.Ident("domain"))
}
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return account, err
}
-func (a *accountDB) GetAccountLastPosted(accountID string) (time.Time, db.Error) {
- status := &gtsmodel.Status{}
+func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, db.Error) {
+ status := new(gtsmodel.Status)
- q := a.conn.Model(status).
+ q := a.conn.
+ NewSelect().
+ Model(status).
Order("id DESC").
Limit(1).
Where("account_id = ?", accountID).
Column("created_at")
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return status.CreatedAt, err
}
-func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
+func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
if mediaAttachment.Avatar && mediaAttachment.Header {
return errors.New("one media attachment cannot be both header and avatar")
}
@@ -127,51 +149,66 @@ func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAtta
}
// TODO: there are probably more side effects here that need to be handled
- if _, err := a.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil {
+ if _, err := a.conn.
+ NewInsert().
+ Model(mediaAttachment).
+ Exec(ctx); err != nil {
return err
}
- if _, err := a.conn.Model(&gtsmodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil {
+ if _, err := a.conn.
+ NewUpdate().
+ Model(&gtsmodel.Account{}).
+ Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).
+ Where("id = ?", accountID).
+ Exec(ctx); err != nil {
return err
}
return nil
}
-func (a *accountDB) GetLocalAccountByUsername(username string) (*gtsmodel.Account, db.Error) {
- account := &gtsmodel.Account{}
+func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, db.Error) {
+ account := new(gtsmodel.Account)
q := a.newAccountQ(account).
Where("username = ?", username).
- Where("? IS NULL", pg.Ident("domain"))
+ Where("? IS NULL", bun.Ident("domain"))
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return account, err
}
-func (a *accountDB) GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, db.Error) {
- faves := []*gtsmodel.StatusFave{}
+func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) {
+ faves := new([]*gtsmodel.StatusFave)
- if err := a.conn.Model(&faves).
+ if err := a.conn.
+ NewSelect().
+ Model(faves).
Where("account_id = ?", accountID).
- Select(); err != nil {
- if err == pg.ErrNoRows {
- return faves, nil
- }
+ Scan(ctx); err != nil {
return nil, err
}
- return faves, nil
+ return *faves, nil
}
-func (a *accountDB) CountAccountStatuses(accountID string) (int, db.Error) {
- return a.conn.Model(&gtsmodel.Status{}).Where("account_id = ?", accountID).Count()
+func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) {
+ return a.conn.
+ NewSelect().
+ Model(&gtsmodel.Status{}).
+ Where("account_id = ?", accountID).
+ Count(ctx)
}
-func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) {
+func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) {
a.log.Debugf("getting statuses for account %s", accountID)
statuses := []*gtsmodel.Status{}
- q := a.conn.Model(&statuses).Order("id DESC")
+ q := a.conn.
+ NewSelect().
+ Model(&statuses).
+ Order("id DESC")
+
if accountID != "" {
q = q.Where("account_id = ?", accountID)
}
@@ -181,27 +218,26 @@ func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeRepli
}
if excludeReplies {
- q = q.Where("? IS NULL", pg.Ident("in_reply_to_id"))
+ q = q.Where("? IS NULL", bun.Ident("in_reply_to_id"))
}
if pinnedOnly {
q = q.Where("pinned = ?", true)
}
- if mediaOnly {
- q = q.WhereGroup(func(q *pg.Query) (*pg.Query, error) {
- return q.Where("? IS NOT NULL", pg.Ident("attachments")).Where("attachments != '{}'"), nil
- })
- }
-
if maxID != "" {
q = q.Where("id < ?", maxID)
}
- if err := q.Select(); err != nil {
- if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries
- }
+ if mediaOnly {
+ q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery {
+ return q.
+ WhereOr("? IS NOT NULL", bun.Ident("attachments")).
+ WhereOr("attachments != '{}'")
+ })
+ }
+
+ if err := q.Scan(ctx); err != nil {
return nil, err
}
@@ -213,10 +249,12 @@ func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeRepli
return statuses, nil
}
-func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) {
+func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) {
blocks := []*gtsmodel.Block{}
- fq := a.conn.Model(&blocks).
+ fq := a.conn.
+ NewSelect().
+ Model(&blocks).
Where("block.account_id = ?", accountID).
Relation("TargetAccount").
Order("block.id DESC")
@@ -233,11 +271,8 @@ func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID str
fq = fq.Limit(limit)
}
- err := fq.Select()
+ err := fq.Scan(ctx)
if err != nil {
- if err == pg.ErrNoRows {
- return nil, "", "", db.ErrNoEntries
- }
return nil, "", "", err
}
diff --git a/internal/db/pg/account_test.go b/internal/db/bundb/account_test.go
index 7ea5ff39a..7174b781d 100644
--- a/internal/db/pg/account_test.go
+++ b/internal/db/bundb/account_test.go
@@ -16,17 +16,19 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg_test
+package bundb_test
import (
+ "context"
"testing"
+ "time"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type AccountTestSuite struct {
- PGStandardTestSuite
+ BunDBStandardTestSuite
}
func (suite *AccountTestSuite) SetupSuite() {
@@ -54,7 +56,7 @@ func (suite *AccountTestSuite) TearDownTest() {
}
func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
- account, err := suite.db.GetAccountByID(suite.testAccounts["local_account_1"].ID)
+ account, err := suite.db.GetAccountByID(context.Background(), suite.testAccounts["local_account_1"].ID)
if err != nil {
suite.FailNow(err.Error())
}
@@ -65,6 +67,20 @@ func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
suite.NotEmpty(account.HeaderMediaAttachment.URL)
}
+func (suite *AccountTestSuite) TestUpdateAccount() {
+ testAccount := suite.testAccounts["local_account_1"]
+
+ testAccount.DisplayName = "new display name!"
+
+ _, err := suite.db.UpdateAccount(context.Background(), testAccount)
+ suite.NoError(err)
+
+ updated, err := suite.db.GetAccountByID(context.Background(), testAccount.ID)
+ suite.NoError(err)
+ suite.Equal("new display name!", updated.DisplayName)
+ suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second)
+}
+
func TestAccountTestSuite(t *testing.T) {
suite.Run(t, new(AccountTestSuite))
}
diff --git a/internal/db/pg/admin.go b/internal/db/bundb/admin.go
index 854f56ef0..67a1e8a0d 100644
--- a/internal/db/pg/admin.go
+++ b/internal/db/bundb/admin.go
@@ -16,76 +16,76 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg
+package bundb
import (
"context"
"crypto/rand"
"crypto/rsa"
+ "database/sql"
"fmt"
"net"
"net/mail"
"strings"
"time"
- "github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/util"
+ "github.com/uptrace/bun"
"golang.org/x/crypto/bcrypt"
)
type adminDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
-func (a *adminDB) IsUsernameAvailable(username string) db.Error {
- // if no error we fail because it means we found something
- // if error but it's not pg.ErrNoRows then we fail
- // if err is pg.ErrNoRows we're good, we found nothing so continue
- if err := a.conn.Model(&gtsmodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
- return fmt.Errorf("username %s already in use", username)
- } else if err != pg.ErrNoRows {
- return fmt.Errorf("db error: %s", err)
- }
- return nil
+func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
+ q := a.conn.
+ NewSelect().
+ Model(&gtsmodel.Account{}).
+ Where("username = ?", username).
+ Where("domain = ?", nil)
+
+ return notExists(ctx, q)
}
-func (a *adminDB) IsEmailAvailable(email string) db.Error {
+func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.Error) {
// parse the domain from the email
m, err := mail.ParseAddress(email)
if err != nil {
- return fmt.Errorf("error parsing email address %s: %s", email, err)
+ return false, fmt.Errorf("error parsing email address %s: %s", email, err)
}
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
// check if the email domain is blocked
- if err := a.conn.Model(&gtsmodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
+ if err := a.conn.
+ NewSelect().
+ Model(&gtsmodel.EmailDomainBlock{}).
+ Where("domain = ?", domain).
+ Scan(ctx); err == nil {
// fail because we found something
- return fmt.Errorf("email domain %s is blocked", domain)
- } else if err != pg.ErrNoRows {
- // fail because we got an unexpected error
- return fmt.Errorf("db error: %s", err)
+ return false, fmt.Errorf("email domain %s is blocked", domain)
+ } else if err != sql.ErrNoRows {
+ return false, processErrorResponse(err)
}
// check if this email is associated with a user already
- if err := a.conn.Model(&gtsmodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
- // fail because we found something
- return fmt.Errorf("email %s already in use", email)
- } else if err != pg.ErrNoRows {
- // fail because we got an unexpected error
- return fmt.Errorf("db error: %s", err)
- }
- return nil
+ q := a.conn.
+ NewSelect().
+ Model(&gtsmodel.User{}).
+ Where("email = ?", email).
+ WhereOr("unconfirmed_email = ?", email)
+
+ return notExists(ctx, q)
}
-func (a *adminDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
+func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
a.log.Errorf("error creating new rsa key: %s", err)
@@ -94,13 +94,12 @@ func (a *adminDB) NewSignup(username string, reason string, requireApproval bool
// if something went wrong while creating a user, we might already have an account, so check here first...
acct := &gtsmodel.Account{}
- err = a.conn.Model(acct).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select()
+ err = a.conn.NewSelect().
+ Model(acct).
+ Where("username = ?", username).
+ Where("? IS NULL", bun.Ident("domain")).
+ Scan(ctx)
if err != nil {
- // there's been an actual error
- if err != pg.ErrNoRows {
- return nil, fmt.Errorf("db error checking existence of account: %s", err)
- }
-
// we just don't have an account yet create one
newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
newAccountID, err := id.NewRandomULID()
@@ -125,7 +124,10 @@ func (a *adminDB) NewSignup(username string, reason string, requireApproval bool
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
- if _, err = a.conn.Model(acct).Insert(); err != nil {
+ if _, err = a.conn.
+ NewInsert().
+ Model(acct).
+ Exec(ctx); err != nil {
return nil, err
}
}
@@ -161,15 +163,33 @@ func (a *adminDB) NewSignup(username string, reason string, requireApproval bool
u.Moderator = true
}
- if _, err = a.conn.Model(u).Insert(); err != nil {
+ if _, err = a.conn.
+ NewInsert().
+ Model(u).
+ Exec(ctx); err != nil {
return nil, err
}
return u, nil
}
-func (a *adminDB) CreateInstanceAccount() db.Error {
+func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
username := a.config.Host
+
+ // check if instance account already exists
+ existsQ := a.conn.
+ NewSelect().
+ Model(&gtsmodel.Account{}).
+ Where("username = ?", username).
+ Where("? IS NULL", bun.Ident("domain"))
+ count, err := existsQ.Count(ctx)
+ if err != nil && count == 1 {
+ a.log.Infof("instance account %s already exists", username)
+ return nil
+ } else if err != sql.ErrNoRows {
+ return processErrorResponse(err)
+ }
+
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
a.log.Errorf("error creating new rsa key: %s", err)
@@ -198,19 +218,36 @@ func (a *adminDB) CreateInstanceAccount() db.Error {
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
- inserted, err := a.conn.Model(acct).Where("username = ?", username).SelectOrInsert()
- if err != nil {
+
+ insertQ := a.conn.
+ NewInsert().
+ Model(acct)
+
+ if _, err := insertQ.Exec(ctx); err != nil {
return err
}
- if inserted {
- a.log.Infof("created instance account %s with id %s", username, acct.ID)
- } else {
- a.log.Infof("instance account %s already exists with id %s", username, acct.ID)
- }
+
+ a.log.Infof("instance account %s CREATED with id %s", username, acct.ID)
return nil
}
-func (a *adminDB) CreateInstanceInstance() db.Error {
+func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
+ domain := a.config.Host
+
+ // check if instance entry already exists
+ existsQ := a.conn.
+ NewSelect().
+ Model(&gtsmodel.Instance{}).
+ Where("domain = ?", domain)
+
+ count, err := existsQ.Count(ctx)
+ if err != nil && count == 1 {
+ a.log.Infof("instance instance %s already exists", domain)
+ return nil
+ } else if err != sql.ErrNoRows {
+ return processErrorResponse(err)
+ }
+
iID, err := id.NewRandomULID()
if err != nil {
return err
@@ -218,18 +255,18 @@ func (a *adminDB) CreateInstanceInstance() db.Error {
i := &gtsmodel.Instance{
ID: iID,
- Domain: a.config.Host,
- Title: a.config.Host,
+ Domain: domain,
+ Title: domain,
URI: fmt.Sprintf("%s://%s", a.config.Protocol, a.config.Host),
}
- inserted, err := a.conn.Model(i).Where("domain = ?", a.config.Host).SelectOrInsert()
- if err != nil {
+
+ insertQ := a.conn.
+ NewInsert().
+ Model(i)
+
+ if _, err := insertQ.Exec(ctx); err != nil {
return err
}
- if inserted {
- a.log.Infof("created instance instance %s with id %s", a.config.Host, i.ID)
- } else {
- a.log.Infof("instance instance %s already exists with id %s", a.config.Host, i.ID)
- }
+ a.log.Infof("created instance instance %s with id %s", domain, i.ID)
return nil
}
diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go
new file mode 100644
index 000000000..983b6b810
--- /dev/null
+++ b/internal/db/bundb/basic.go
@@ -0,0 +1,179 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "context"
+ "errors"
+ "strings"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/uptrace/bun"
+)
+
+type basicDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (b *basicDB) Put(ctx context.Context, i interface{}) db.Error {
+ _, err := b.conn.NewInsert().Model(i).Exec(ctx)
+ if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
+ return db.ErrAlreadyExists
+ }
+ return err
+}
+
+func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Error {
+ q := b.conn.
+ NewSelect().
+ Model(i).
+ Where("id = ?", id)
+
+ return processErrorResponse(q.Scan(ctx))
+}
+
+func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
+ if len(where) == 0 {
+ return errors.New("no queries provided")
+ }
+
+ q := b.conn.NewSelect().Model(i)
+ for _, w := range where {
+
+ if w.Value == nil {
+ q = q.Where("? IS NULL", bun.Ident(w.Key))
+ } else {
+ if w.CaseInsensitive {
+ q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value)
+ } else {
+ q = q.Where("? = ?", bun.Safe(w.Key), w.Value)
+ }
+ }
+ }
+
+ return processErrorResponse(q.Scan(ctx))
+}
+
+func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error {
+ q := b.conn.
+ NewSelect().
+ Model(i)
+
+ return processErrorResponse(q.Scan(ctx))
+}
+
+func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error {
+ q := b.conn.
+ NewDelete().
+ Model(i).
+ Where("id = ?", id)
+
+ _, err := q.Exec(ctx)
+
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
+ if len(where) == 0 {
+ return errors.New("no queries provided")
+ }
+
+ q := b.conn.
+ NewDelete().
+ Model(i)
+
+ for _, w := range where {
+ q = q.Where("? = ?", bun.Safe(w.Key), w.Value)
+ }
+
+ _, err := q.Exec(ctx)
+
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.Error {
+ q := b.conn.
+ NewUpdate().
+ Model(i).
+ WherePK()
+
+ _, err := q.Exec(ctx)
+
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) db.Error {
+ q := b.conn.NewUpdate().
+ Model(i).
+ Set("? = ?", bun.Safe(key), value).
+ WherePK()
+
+ _, err := q.Exec(ctx)
+
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error {
+ q := b.conn.NewUpdate().Model(i)
+
+ for _, w := range where {
+ if w.Value == nil {
+ q = q.Where("? IS NULL", bun.Ident(w.Key))
+ } else {
+ if w.CaseInsensitive {
+ q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value)
+ } else {
+ q = q.Where("? = ?", bun.Safe(w.Key), w.Value)
+ }
+ }
+ }
+
+ q = q.Set("? = ?", bun.Safe(key), value)
+
+ _, err := q.Exec(ctx)
+
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
+ _, err := b.conn.NewCreateTable().Model(i).IfNotExists().Exec(ctx)
+ return err
+}
+
+func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error {
+ _, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx)
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
+ return b.conn.Ping()
+}
+
+func (b *basicDB) Stop(ctx context.Context) db.Error {
+ b.log.Info("closing db connection")
+ if err := b.conn.Close(); err != nil {
+ // only cancel if there's a problem closing the db
+ return err
+ }
+ return nil
+}
diff --git a/internal/db/bundb/basic_test.go b/internal/db/bundb/basic_test.go
new file mode 100644
index 000000000..9189618c9
--- /dev/null
+++ b/internal/db/bundb/basic_test.go
@@ -0,0 +1,68 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type BasicTestSuite struct {
+ BunDBStandardTestSuite
+}
+
+func (suite *BasicTestSuite) SetupSuite() {
+ suite.testTokens = testrig.NewTestTokens()
+ suite.testClients = testrig.NewTestClients()
+ suite.testApplications = testrig.NewTestApplications()
+ suite.testUsers = testrig.NewTestUsers()
+ suite.testAccounts = testrig.NewTestAccounts()
+ suite.testAttachments = testrig.NewTestAttachments()
+ suite.testStatuses = testrig.NewTestStatuses()
+ suite.testTags = testrig.NewTestTags()
+ suite.testMentions = testrig.NewTestMentions()
+}
+
+func (suite *BasicTestSuite) SetupTest() {
+ suite.config = testrig.NewTestConfig()
+ suite.db = testrig.NewTestDB()
+ suite.log = testrig.NewTestLog()
+
+ testrig.StandardDBSetup(suite.db, suite.testAccounts)
+}
+
+func (suite *BasicTestSuite) TearDownTest() {
+ testrig.StandardDBTeardown(suite.db)
+}
+
+func (suite *BasicTestSuite) TestGetAccountByID() {
+ testAccount := suite.testAccounts["local_account_1"]
+
+ a := &gtsmodel.Account{}
+ err := suite.db.GetByID(context.Background(), testAccount.ID, a)
+ suite.NoError(err)
+}
+
+func TestBasicTestSuite(t *testing.T) {
+ suite.Run(t, new(BasicTestSuite))
+}
diff --git a/internal/db/pg/pg.go b/internal/db/bundb/bundb.go
index 0437baf02..49ed09cbd 100644
--- a/internal/db/pg/pg.go
+++ b/internal/db/bundb/bundb.go
@@ -16,12 +16,13 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg
+package bundb
import (
"context"
"crypto/tls"
"crypto/x509"
+ "database/sql"
"encoding/pem"
"errors"
"fmt"
@@ -29,14 +30,20 @@ import (
"strings"
"time"
- "github.com/go-pg/pg/extra/pgdebug"
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
+ "github.com/jackc/pgx/v4"
+ "github.com/jackc/pgx/v4/stdlib"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
+ "github.com/uptrace/bun"
+ "github.com/uptrace/bun/dialect/pgdialect"
+)
+
+const (
+ dbTypePostgres = "postgres"
+ dbTypeSqlite = "sqlite"
)
var registerTables []interface{} = []interface{}{
@@ -44,8 +51,8 @@ var registerTables []interface{} = []interface{}{
&gtsmodel.StatusToTag{},
}
-// postgresService satisfies the DB interface
-type postgresService struct {
+// bunDBService satisfies the DB interface
+type bunDBService struct {
db.Account
db.Admin
db.Basic
@@ -55,130 +62,115 @@ type postgresService struct {
db.Mention
db.Notification
db.Relationship
+ db.Session
db.Status
db.Timeline
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
-// NewPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
-// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
-func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
- for _, t := range registerTables {
- // https://pg.uptrace.dev/orm/many-to-many-relation/
- orm.RegisterTable(t)
- }
-
- opts, err := derivePGOptions(c)
- if err != nil {
- return nil, fmt.Errorf("could not create postgres service: %s", err)
- }
- log.Debugf("using pg options: %+v", opts)
-
- // create a connection
- pgCtx, cancel := context.WithCancel(ctx)
- conn := pg.Connect(opts).WithContext(pgCtx)
-
- // this will break the logfmt format we normally log in,
- // since we can't choose where pg outputs to and it defaults to
- // stdout. So use this option with care!
- if log.GetLevel() >= logrus.TraceLevel {
- conn.AddQueryHook(pgdebug.DebugHook{
- // Print all queries.
- Verbose: true,
- })
+// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface.
+// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
+func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
+ var sqldb *sql.DB
+ var conn *bun.DB
+
+ // depending on the database type we're trying to create, we need to use a different driver...
+ switch strings.ToLower(c.DBConfig.Type) {
+ case dbTypePostgres:
+ // POSTGRES
+ opts, err := deriveBunDBPGOptions(c)
+ if err != nil {
+ return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
+ }
+ sqldb = stdlib.OpenDB(*opts)
+ conn = bun.NewDB(sqldb, pgdialect.New())
+ case dbTypeSqlite:
+ // SQLITE
+ // TODO: https://bun.uptrace.dev/guide/drivers.html#sqlite
+ default:
+ return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type))
}
// actually *begin* the connection so that we can tell if the db is there and listening
- if err := conn.Ping(ctx); err != nil {
- cancel()
+ if err := conn.Ping(); err != nil {
return nil, fmt.Errorf("db connection error: %s", err)
}
+ log.Info("connected to database")
- // print out discovered postgres version
- var version string
- if _, err = conn.QueryOneContext(ctx, pg.Scan(&version), "SELECT version()"); err != nil {
- cancel()
- return nil, fmt.Errorf("db connection error: %s", err)
+ for _, t := range registerTables {
+ // https://bun.uptrace.dev/orm/many-to-many-relation/
+ conn.RegisterModel(t)
}
- log.Infof("connected to postgres version: %s", version)
- ps := &postgresService{
+ ps := &bunDBService{
Account: &accountDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Admin: &adminDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Basic: &basicDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Domain: &domainDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Instance: &instanceDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Media: &mediaDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Mention: &mentionDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Notification: &notificationDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Relationship: &relationshipDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
+ },
+ Session: &sessionDB{
+ config: c,
+ conn: conn,
+ log: log,
},
Status: &statusDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
Timeline: &timelineDB{
config: c,
conn: conn,
log: log,
- cancel: cancel,
},
config: c,
conn: conn,
log: log,
- cancel: cancel,
}
- // we can confidently return this useable postgres service now
+ // we can confidently return this useable service now
return ps, nil
}
@@ -186,9 +178,9 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
HANDY STUFF
*/
-// derivePGOptions takes an application config and returns either a ready-to-use *pg.Options
+// deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options
// with sensible defaults, or an error if it's not satisfied by the provided config.
-func derivePGOptions(c *config.Config) (*pg.Options, error) {
+func deriveBunDBPGOptions(c *config.Config) (*pgx.ConnConfig, error) {
if strings.ToUpper(c.DBConfig.Type) != db.DBTypePostgres {
return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, c.DBConfig.Type)
}
@@ -266,18 +258,16 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
tlsConfig.RootCAs = certPool
}
- // We can rely on the pg library we're using to set
- // sensible defaults for everything we don't set here.
- options := &pg.Options{
- Addr: fmt.Sprintf("%s:%d", c.DBConfig.Address, c.DBConfig.Port),
- User: c.DBConfig.User,
- Password: c.DBConfig.Password,
- Database: c.DBConfig.Database,
- ApplicationName: c.ApplicationName,
- TLSConfig: tlsConfig,
- }
+ cfg, _ := pgx.ParseConfig("")
+ cfg.Host = c.DBConfig.Address
+ cfg.Port = uint16(c.DBConfig.Port)
+ cfg.User = c.DBConfig.User
+ cfg.Password = c.DBConfig.Password
+ cfg.TLSConfig = tlsConfig
+ cfg.Database = c.DBConfig.Database
+ cfg.PreferSimpleProtocol = true
- return options, nil
+ return cfg, nil
}
/*
@@ -286,9 +276,9 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
// TODO: move these to the type converter, it's bananas that they're here and not there
-func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) {
+func (ps *bunDBService) MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) {
ogAccount := &gtsmodel.Account{}
- if err := ps.conn.Model(ogAccount).Where("id = ?", originAccountID).Select(); err != nil {
+ if err := ps.conn.NewSelect().Model(ogAccount).Where("id = ?", originAccountID).Scan(ctx); err != nil {
return nil, err
}
@@ -333,14 +323,14 @@ func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, ori
// match username + account, case insensitive
if local {
// local user -- should have a null domain
- err = ps.conn.Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", pg.Ident("username"), username).Where("? IS NULL", pg.Ident("domain")).Select()
+ err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("? IS NULL", bun.Ident("domain")).Scan(ctx)
} else {
// remote user -- should have domain defined
- err = ps.conn.Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", pg.Ident("username"), username).Where("LOWER(?) = LOWER(?)", pg.Ident("domain"), domain).Select()
+ err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("LOWER(?) = LOWER(?)", bun.Ident("domain"), domain).Scan(ctx)
}
if err != nil {
- if err == pg.ErrNoRows {
+ if err == sql.ErrNoRows {
// no result found for this username/domain so just don't include it as a mencho and carry on about our business
ps.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain)
continue
@@ -364,14 +354,14 @@ func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, ori
return menchies, nil
}
-func (ps *postgresService) TagStringsToTags(tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) {
+func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) {
newTags := []*gtsmodel.Tag{}
for _, t := range tags {
tag := &gtsmodel.Tag{}
// we can use selectorinsert here to create the new tag if it doesn't exist already
// inserted will be true if this is a new tag we just created
- if err := ps.conn.Model(tag).Where("LOWER(?) = LOWER(?)", pg.Ident("name"), t).Select(); err != nil {
- if err == pg.ErrNoRows {
+ if err := ps.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil {
+ if err == sql.ErrNoRows {
// tag doesn't exist yet so populate it
newID, err := id.NewRandomULID()
if err != nil {
@@ -400,13 +390,13 @@ func (ps *postgresService) TagStringsToTags(tags []string, originAccountID strin
return newTags, nil
}
-func (ps *postgresService) EmojiStringsToEmojis(emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) {
+func (ps *bunDBService) EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) {
newEmojis := []*gtsmodel.Emoji{}
for _, e := range emojis {
emoji := &gtsmodel.Emoji{}
- err := ps.conn.Model(emoji).Where("shortcode = ?", e).Where("visible_in_picker = true").Where("disabled = false").Select()
+ err := ps.conn.NewSelect().Model(emoji).Where("shortcode = ?", e).Where("visible_in_picker = true").Where("disabled = false").Scan(ctx)
if err != nil {
- if err == pg.ErrNoRows {
+ if err == sql.ErrNoRows {
// no result found for this username/domain so just don't include it as an emoji and carry on about our business
ps.log.Debugf("no emoji found with shortcode %s, skipping it", e)
continue
diff --git a/internal/db/pg/pg_test.go b/internal/db/bundb/bundb_test.go
index c1e10abdf..b789375af 100644
--- a/internal/db/pg/pg_test.go
+++ b/internal/db/bundb/bundb_test.go
@@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg_test
+package bundb_test
import (
"github.com/sirupsen/logrus"
@@ -27,7 +27,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
-type PGStandardTestSuite struct {
+type BunDBStandardTestSuite struct {
// standard suite interfaces
suite.Suite
config *config.Config
diff --git a/internal/db/pg/domain.go b/internal/db/bundb/domain.go
index 4e9b2ab48..6aa2b8ffe 100644
--- a/internal/db/pg/domain.go
+++ b/internal/db/bundb/domain.go
@@ -16,48 +16,46 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg
+package bundb
import (
"context"
"net/url"
- "github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
+ "github.com/uptrace/bun"
)
type domainDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
-func (d *domainDB) IsDomainBlocked(domain string) (bool, db.Error) {
+func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) {
if domain == "" {
return false, nil
}
- blocked, err := d.conn.
+ q := d.conn.
+ NewSelect().
Model(&gtsmodel.DomainBlock{}).
Where("LOWER(domain) = LOWER(?)", domain).
- Exists()
+ Limit(1)
- err = processErrorResponse(err)
-
- return blocked, err
+ return exists(ctx, q)
}
-func (d *domainDB) AreDomainsBlocked(domains []string) (bool, db.Error) {
+func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) {
// filter out any doubles
uniqueDomains := util.UniqueStrings(domains)
for _, domain := range uniqueDomains {
- if blocked, err := d.IsDomainBlocked(domain); err != nil {
+ if blocked, err := d.IsDomainBlocked(ctx, domain); err != nil {
return false, err
} else if blocked {
return blocked, nil
@@ -68,16 +66,16 @@ func (d *domainDB) AreDomainsBlocked(domains []string) (bool, db.Error) {
return false, nil
}
-func (d *domainDB) IsURIBlocked(uri *url.URL) (bool, db.Error) {
+func (d *domainDB) IsURIBlocked(ctx context.Context, uri *url.URL) (bool, db.Error) {
domain := uri.Hostname()
- return d.IsDomainBlocked(domain)
+ return d.IsDomainBlocked(ctx, domain)
}
-func (d *domainDB) AreURIsBlocked(uris []*url.URL) (bool, db.Error) {
+func (d *domainDB) AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, db.Error) {
domains := []string{}
for _, uri := range uris {
domains = append(domains, uri.Hostname())
}
- return d.AreDomainsBlocked(domains)
+ return d.AreDomainsBlocked(ctx, domains)
}
diff --git a/internal/db/pg/instance.go b/internal/db/bundb/instance.go
index 968832ca5..f9364346e 100644
--- a/internal/db/pg/instance.go
+++ b/internal/db/bundb/instance.go
@@ -16,43 +16,50 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg
+package bundb
import (
"context"
- "github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type instanceDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
-func (i *instanceDB) CountInstanceUsers(domain string) (int, db.Error) {
- q := i.conn.Model(&[]*gtsmodel.Account{})
+func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {
+ q := i.conn.
+ NewSelect().
+ Model(&[]*gtsmodel.Account{})
if domain == i.config.Host {
// if the domain is *this* domain, just count where the domain field is null
- q = q.Where("? IS NULL", pg.Ident("domain"))
+ q = q.Where("? IS NULL", bun.Ident("domain"))
} else {
q = q.Where("domain = ?", domain)
}
// don't count the instance account or suspended users
- q = q.Where("username != ?", domain).Where("? IS NULL", pg.Ident("suspended_at"))
+ q = q.
+ Where("username != ?", domain).
+ Where("? IS NULL", bun.Ident("suspended_at"))
- return q.Count()
+ count, err := q.Count(ctx)
+
+ return count, processErrorResponse(err)
}
-func (i *instanceDB) CountInstanceStatuses(domain string) (int, db.Error) {
- q := i.conn.Model(&[]*gtsmodel.Status{})
+func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) {
+ q := i.conn.
+ NewSelect().
+ Model(&[]*gtsmodel.Status{})
if domain == i.config.Host {
// if the domain is *this* domain, just count where local is true
@@ -63,30 +70,39 @@ func (i *instanceDB) CountInstanceStatuses(domain string) (int, db.Error) {
Where("account.domain = ?", domain)
}
- return q.Count()
+ count, err := q.Count(ctx)
+
+ return count, processErrorResponse(err)
}
-func (i *instanceDB) CountInstanceDomains(domain string) (int, db.Error) {
- q := i.conn.Model(&[]*gtsmodel.Instance{})
+func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) {
+ q := i.conn.
+ NewSelect().
+ Model(&[]*gtsmodel.Instance{})
if domain == i.config.Host {
// if the domain is *this* domain, just count other instances it knows about
// exclude domains that are blocked
- q = q.Where("domain != ?", domain).Where("? IS NULL", pg.Ident("suspended_at"))
+ q = q.Where("domain != ?", domain).Where("? IS NULL", bun.Ident("suspended_at"))
} else {
// TODO: implement federated domain counting properly for remote domains
return 0, nil
}
- return q.Count()
+ count, err := q.Count(ctx)
+
+ return count, processErrorResponse(err)
}
-func (i *instanceDB) GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
+func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
i.log.Debug("GetAccountsForInstance")
accounts := []*gtsmodel.Account{}
- q := i.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC")
+ q := i.conn.NewSelect().
+ Model(&accounts).
+ Where("domain = ?", domain).
+ Order("id DESC")
if maxID != "" {
q = q.Where("id < ?", maxID)
@@ -96,17 +112,7 @@ func (i *instanceDB) GetInstanceAccounts(domain string, maxID string, limit int)
q = q.Limit(limit)
}
- err := q.Select()
- if err != nil {
- if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries
- }
- return nil, err
- }
-
- if len(accounts) == 0 {
- return nil, db.ErrNoEntries
- }
+ err := processErrorResponse(q.Scan(ctx))
- return accounts, nil
+ return accounts, err
}
diff --git a/internal/db/pg/media.go b/internal/db/bundb/media.go
index 618030af3..04e55ca62 100644
--- a/internal/db/pg/media.go
+++ b/internal/db/bundb/media.go
@@ -16,38 +16,38 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg
+package bundb
import (
"context"
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type mediaDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
-func (m *mediaDB) newMediaQ(i interface{}) *orm.Query {
- return m.conn.Model(i).
+func (m *mediaDB) newMediaQ(i interface{}) *bun.SelectQuery {
+ return m.conn.
+ NewSelect().
+ Model(i).
Relation("Account")
}
-func (m *mediaDB) GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, db.Error) {
+func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, db.Error) {
attachment := &gtsmodel.MediaAttachment{}
q := m.newMediaQ(attachment).
Where("media_attachment.id = ?", id)
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return attachment, err
}
diff --git a/internal/db/pg/mention.go b/internal/db/bundb/mention.go
index b31f07b67..a444f9b5f 100644
--- a/internal/db/pg/mention.go
+++ b/internal/db/bundb/mention.go
@@ -16,25 +16,23 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg
+package bundb
import (
"context"
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type mentionDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
cache cache.Cache
}
@@ -67,14 +65,16 @@ func (m *mentionDB) mentionCached(id string) (*gtsmodel.Mention, bool) {
return mention, true
}
-func (m *mentionDB) newMentionQ(i interface{}) *orm.Query {
- return m.conn.Model(i).
+func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {
+ return m.conn.
+ NewSelect().
+ Model(i).
Relation("Status").
Relation("OriginAccount").
Relation("TargetAccount")
}
-func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) {
+func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
if mention, cached := m.mentionCached(id); cached {
return mention, nil
}
@@ -84,7 +84,7 @@ func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) {
q := m.newMentionQ(mention).
Where("mention.id = ?", id)
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
if err == nil && mention != nil {
m.cacheMention(id, mention)
@@ -93,11 +93,11 @@ func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) {
return mention, err
}
-func (m *mentionDB) GetMentions(ids []string) ([]*gtsmodel.Mention, db.Error) {
+func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) {
mentions := []*gtsmodel.Mention{}
for _, i := range ids {
- mention, err := m.GetMention(i)
+ mention, err := m.GetMention(ctx, i)
if err != nil {
return nil, processErrorResponse(err)
}
diff --git a/internal/db/pg/notification.go b/internal/db/bundb/notification.go
index 281a76d85..1c30837ec 100644
--- a/internal/db/pg/notification.go
+++ b/internal/db/bundb/notification.go
@@ -16,25 +16,23 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg
+package bundb
import (
"context"
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type notificationDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
cache cache.Cache
}
@@ -67,14 +65,16 @@ func (n *notificationDB) notificationCached(id string) (*gtsmodel.Notification,
return notification, true
}
-func (n *notificationDB) newNotificationQ(i interface{}) *orm.Query {
- return n.conn.Model(i).
+func (n *notificationDB) newNotificationQ(i interface{}) *bun.SelectQuery {
+ return n.conn.
+ NewSelect().
+ Model(i).
Relation("OriginAccount").
Relation("TargetAccount").
Relation("Status")
}
-func (n *notificationDB) GetNotification(id string) (*gtsmodel.Notification, db.Error) {
+func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
if notification, cached := n.notificationCached(id); cached {
return notification, nil
}
@@ -84,7 +84,7 @@ func (n *notificationDB) GetNotification(id string) (*gtsmodel.Notification, db.
q := n.newNotificationQ(notification).
Where("notification.id = ?", id)
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
if err == nil && notification != nil {
n.cacheNotification(id, notification)
@@ -93,10 +93,11 @@ func (n *notificationDB) GetNotification(id string) (*gtsmodel.Notification, db.
return notification, err
}
-func (n *notificationDB) GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
+func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
// begin by selecting just the IDs
notifIDs := []*gtsmodel.Notification{}
q := n.conn.
+ NewSelect().
Model(&notifIDs).
Column("id").
Where("target_account_id = ?", accountID).
@@ -114,7 +115,7 @@ func (n *notificationDB) GetNotifications(accountID string, limit int, maxID str
q = q.Limit(limit)
}
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
if err != nil {
return nil, err
}
@@ -123,7 +124,7 @@ func (n *notificationDB) GetNotifications(accountID string, limit int, maxID str
// reason for this is that for each notif, we can instead get it from our cache if it's cached
notifications := []*gtsmodel.Notification{}
for _, notifID := range notifIDs {
- notif, err := n.GetNotification(notifID.ID)
+ notif, err := n.GetNotification(ctx, notifID.ID)
errP := processErrorResponse(err)
if errP != nil {
return nil, errP
diff --git a/internal/db/pg/relationship.go b/internal/db/bundb/relationship.go
index 76bd50c76..ccc604baf 100644
--- a/internal/db/pg/relationship.go
+++ b/internal/db/bundb/relationship.go
@@ -16,44 +16,49 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg
+package bundb
import (
"context"
+ "database/sql"
"fmt"
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type relationshipDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
-func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *orm.Query {
- return r.conn.Model(block).
+func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *bun.SelectQuery {
+ return r.conn.
+ NewSelect().
+ Model(block).
Relation("Account").
Relation("TargetAccount")
}
-func (r *relationshipDB) newFollowQ(follow interface{}) *orm.Query {
- return r.conn.Model(follow).
+func (r *relationshipDB) newFollowQ(follow interface{}) *bun.SelectQuery {
+ return r.conn.
+ NewSelect().
+ Model(follow).
Relation("Account").
Relation("TargetAccount")
}
-func (r *relationshipDB) IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
+func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
q := r.conn.
+ NewSelect().
Model(&gtsmodel.Block{}).
Where("account_id = ?", account1).
- Where("target_account_id = ?", account2)
+ Where("target_account_id = ?", account2).
+ Limit(1)
if eitherDirection {
q = q.
@@ -61,30 +66,36 @@ func (r *relationshipDB) IsBlocked(account1 string, account2 string, eitherDirec
Where("account_id = ?", account2)
}
- return q.Exists()
+ return exists(ctx, q)
}
-func (r *relationshipDB) GetBlock(account1 string, account2 string) (*gtsmodel.Block, db.Error) {
+func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
block := &gtsmodel.Block{}
q := r.newBlockQ(block).
Where("block.account_id = ?", account1).
Where("block.target_account_id = ?", account2)
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return block, err
}
-func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
+func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
rel := &gtsmodel.Relationship{
ID: targetAccount,
}
// check if the requesting account follows the target account
follow := &gtsmodel.Follow{}
- if err := r.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil {
- if err != pg.ErrNoRows {
+ if err := r.conn.
+ NewSelect().
+ Model(follow).
+ Where("account_id = ?", requestingAccount).
+ Where("target_account_id = ?", targetAccount).
+ Limit(1).
+ Scan(ctx); err != nil {
+ if err != sql.ErrNoRows {
// a proper error
return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
}
@@ -100,75 +111,101 @@ func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount
}
// check if the target account follows the requesting account
- followedBy, err := r.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
+ count, err := r.conn.
+ NewSelect().
+ Model(&gtsmodel.Follow{}).
+ Where("account_id = ?", targetAccount).
+ Where("target_account_id = ?", requestingAccount).
+ Limit(1).
+ Count(ctx)
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
}
- rel.FollowedBy = followedBy
+ rel.FollowedBy = count > 0
// check if the requesting account blocks the target account
- blocking, err := r.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
+ count, err = r.conn.NewSelect().
+ Model(&gtsmodel.Block{}).
+ Where("account_id = ?", requestingAccount).
+ Where("target_account_id = ?", targetAccount).
+ Limit(1).
+ Count(ctx)
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
}
- rel.Blocking = blocking
+ rel.Blocking = count > 0
// check if the target account blocks the requesting account
- blockedBy, err := r.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
+ count, err = r.conn.
+ NewSelect().
+ Model(&gtsmodel.Block{}).
+ Where("account_id = ?", targetAccount).
+ Where("target_account_id = ?", requestingAccount).
+ Limit(1).
+ Count(ctx)
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
- rel.BlockedBy = blockedBy
+ rel.BlockedBy = count > 0
// check if there's a pending following request from requesting account to target account
- requested, err := r.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
+ count, err = r.conn.
+ NewSelect().
+ Model(&gtsmodel.FollowRequest{}).
+ Where("account_id = ?", requestingAccount).
+ Where("target_account_id = ?", targetAccount).
+ Limit(1).
+ Count(ctx)
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
- rel.Requested = requested
+ rel.Requested = count > 0
return rel, nil
}
-func (r *relationshipDB) IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
+func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
q := r.conn.
+ NewSelect().
Model(&gtsmodel.Follow{}).
Where("account_id = ?", sourceAccount.ID).
- Where("target_account_id = ?", targetAccount.ID)
+ Where("target_account_id = ?", targetAccount.ID).
+ Limit(1)
- return q.Exists()
+ return exists(ctx, q)
}
-func (r *relationshipDB) IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
+func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
q := r.conn.
+ NewSelect().
Model(&gtsmodel.FollowRequest{}).
Where("account_id = ?", sourceAccount.ID).
Where("target_account_id = ?", targetAccount.ID)
- return q.Exists()
+ return exists(ctx, q)
}
-func (r *relationshipDB) IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
+func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
if account1 == nil || account2 == nil {
return false, nil
}
// make sure account 1 follows account 2
- f1, err := r.IsFollowing(account1, account2)
+ f1, err := r.IsFollowing(ctx, account1, account2)
if err != nil {
return false, processErrorResponse(err)
}
// make sure account 2 follows account 1
- f2, err := r.IsFollowing(account2, account1)
+ f2, err := r.IsFollowing(ctx, account2, account1)
if err != nil {
return false, processErrorResponse(err)
}
@@ -176,14 +213,16 @@ func (r *relationshipDB) IsMutualFollowing(account1 *gtsmodel.Account, account2
return f1 && f2, nil
}
-func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
+func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
// make sure the original follow request exists
fr := &gtsmodel.FollowRequest{}
- if err := r.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil {
- if err == pg.ErrMultiRows {
- return nil, db.ErrNoEntries
- }
- return nil, err
+ if err := r.conn.
+ NewSelect().
+ Model(fr).
+ Where("account_id = ?", originAccountID).
+ Where("target_account_id = ?", targetAccountID).
+ Scan(ctx); err != nil {
+ return nil, processErrorResponse(err)
}
// create a new follow to 'replace' the request with
@@ -195,82 +234,95 @@ func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccou
}
// if the follow already exists, just update the URI -- we don't need to do anything else
- if _, err := r.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil {
- return nil, err
+ if _, err := r.conn.
+ NewInsert().
+ Model(follow).
+ On("CONFLICT CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).
+ Exec(ctx); err != nil {
+ return nil, processErrorResponse(err)
}
// now remove the follow request
- if _, err := r.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil {
- return nil, err
+ if _, err := r.conn.
+ NewDelete().
+ Model(&gtsmodel.FollowRequest{}).
+ Where("account_id = ?", originAccountID).
+ Where("target_account_id = ?", targetAccountID).
+ Exec(ctx); err != nil {
+ return nil, processErrorResponse(err)
}
return follow, nil
}
-func (r *relationshipDB) GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, db.Error) {
+func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) {
followRequests := []*gtsmodel.FollowRequest{}
q := r.newFollowQ(&followRequests).
Where("target_account_id = ?", accountID)
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return followRequests, err
}
-func (r *relationshipDB) GetAccountFollows(accountID string) ([]*gtsmodel.Follow, db.Error) {
+func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, db.Error) {
follows := []*gtsmodel.Follow{}
q := r.newFollowQ(&follows).
Where("account_id = ?", accountID)
- err := processErrorResponse(q.Select())
+ err := processErrorResponse(q.Scan(ctx))
return follows, err
}
-func (r *relationshipDB) CountAccountFollows(accountID string, localOnly bool) (int, db.Error) {
+func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
return r.conn.
+ NewSelect().
Model(&[]*gtsmodel.Follow{}).
Where("account_id = ?", accountID).
- Count()
+ Count(ctx)
}
-func (r *relationshipDB) GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
+func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
follows := []*gtsmodel.Follow{}
- q := r.conn.Model(&follows)
+ q := r.conn.
+ NewSelect().
+ Model(&follows)
if localOnly {
// for local accounts let's get where domain is null OR where domain is an empty string, just to be safe
- whereGroup := func(q *pg.Query) (*pg.Query, error) {
+ whereGroup := func(q *bun.SelectQuery) *bun.SelectQuery {
q = q.
- WhereOr("? IS NULL", pg.Ident("a.domain")).
+ WhereOr("? IS NULL", bun.Ident("a.domain")).
WhereOr("a.domain = ?", "")
- return q, nil
+ return q
}
q = q.ColumnExpr("follow.*").
Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)").
Where("follow.target_account_id = ?", accountID).
- WhereGroup(whereGroup)
+ WhereGroup(" AND ", whereGroup)
} else {
q = q.Where("target_account_id = ?", accountID)
}
- if err := q.Select(); err != nil {
- if err == pg.ErrNoRows {
+ if err := q.Scan(ctx); err != nil {
+ if err == sql.ErrNoRows {
return follows, nil
}
- return nil, err
+ return nil, processErrorResponse(err)
}
return follows, nil
}
-func (r *relationshipDB) CountAccountFollowedBy(accountID string, localOnly bool) (int, db.Error) {
+func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
return r.conn.
+ NewSelect().
Model(&[]*gtsmodel.Follow{}).
Where("target_account_id = ?", accountID).
- Count()
+ Count(ctx)
}
diff --git a/internal/db/bundb/session.go b/internal/db/bundb/session.go
new file mode 100644
index 000000000..87e20673d
--- /dev/null
+++ b/internal/db/bundb/session.go
@@ -0,0 +1,85 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "context"
+ "crypto/rand"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/id"
+ "github.com/uptrace/bun"
+)
+
+type sessionDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
+ rs := new(gtsmodel.RouterSession)
+
+ q := s.conn.
+ NewSelect().
+ Model(rs).
+ Limit(1)
+
+ _, err := q.Exec(ctx)
+
+ err = processErrorResponse(err)
+
+ return rs, err
+}
+
+func (s *sessionDB) CreateSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
+ auth := make([]byte, 32)
+ crypt := make([]byte, 32)
+
+ if _, err := rand.Read(auth); err != nil {
+ return nil, err
+ }
+ if _, err := rand.Read(crypt); err != nil {
+ return nil, err
+ }
+
+ rid, err := id.NewULID()
+ if err != nil {
+ return nil, err
+ }
+
+ rs := &gtsmodel.RouterSession{
+ ID: rid,
+ Auth: auth,
+ Crypt: crypt,
+ }
+
+ q := s.conn.
+ NewInsert().
+ Model(rs)
+
+ _, err = q.Exec(ctx)
+
+ err = processErrorResponse(err)
+
+ return rs, err
+}
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go
new file mode 100644
index 000000000..da8d8ca41
--- /dev/null
+++ b/internal/db/bundb/status.go
@@ -0,0 +1,375 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "container/list"
+ "context"
+ "errors"
+ "time"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+type statusDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+ cache cache.Cache
+}
+
+func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) {
+ if s.cache == nil {
+ s.cache = cache.New()
+ }
+
+ if err := s.cache.Store(id, status); err != nil {
+ s.log.Panicf("statusDB: error storing in cache: %s", err)
+ }
+}
+
+func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) {
+ if s.cache == nil {
+ s.cache = cache.New()
+ return nil, false
+ }
+
+ sI, err := s.cache.Fetch(id)
+ if err != nil || sI == nil {
+ return nil, false
+ }
+
+ status, ok := sI.(*gtsmodel.Status)
+ if !ok {
+ s.log.Panicf("statusDB: cached interface with key %s was not a status", id)
+ }
+
+ return status, true
+}
+
+func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
+ return s.conn.
+ NewSelect().
+ Model(status).
+ Relation("Attachments").
+ Relation("Tags").
+ Relation("Mentions").
+ Relation("Emojis").
+ Relation("Account").
+ Relation("InReplyToAccount").
+ Relation("BoostOfAccount").
+ Relation("CreatedWithApplication")
+}
+
+func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status {
+ if status.InReplyToID != "" && status.InReplyTo == nil {
+ if inReplyTo, cached := s.statusCached(status.InReplyToID); cached {
+ status.InReplyTo = inReplyTo
+ } else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil {
+ status.InReplyTo = inReplyTo
+ }
+ }
+
+ if status.BoostOfID != "" && status.BoostOf == nil {
+ if boostOf, cached := s.statusCached(status.BoostOfID); cached {
+ status.BoostOf = boostOf
+ } else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil {
+ status.BoostOf = boostOf
+ }
+ }
+
+ return status
+}
+
+func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
+ return s.conn.
+ NewSelect().
+ Model(faves).
+ Relation("Account").
+ Relation("TargetAccount").
+ Relation("Status")
+}
+
+func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) {
+ if status, cached := s.statusCached(id); cached {
+ return status, nil
+ }
+
+ status := new(gtsmodel.Status)
+
+ q := s.newStatusQ(status).
+ Where("status.id = ?", id)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ if err != nil {
+ return nil, err
+ }
+
+ if status != nil {
+ s.cacheStatus(id, status)
+ }
+
+ return s.getAttachedStatuses(ctx, status), err
+}
+
+func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
+ if status, cached := s.statusCached(uri); cached {
+ return status, nil
+ }
+
+ status := &gtsmodel.Status{}
+
+ q := s.newStatusQ(status).
+ Where("LOWER(status.uri) = LOWER(?)", uri)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ if err != nil {
+ return nil, err
+ }
+
+ if status != nil {
+ s.cacheStatus(uri, status)
+ }
+
+ return s.getAttachedStatuses(ctx, status), err
+}
+
+func (s *statusDB) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
+ if status, cached := s.statusCached(uri); cached {
+ return status, nil
+ }
+
+ status := &gtsmodel.Status{}
+
+ q := s.newStatusQ(status).
+ Where("LOWER(status.url) = LOWER(?)", uri)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ if err != nil {
+ return nil, err
+ }
+
+ if status != nil {
+ s.cacheStatus(uri, status)
+ }
+
+ return s.getAttachedStatuses(ctx, status), err
+}
+
+func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
+ transaction := func(ctx context.Context, tx bun.Tx) error {
+ // create links between this status and any emojis it uses
+ for _, i := range status.EmojiIDs {
+ if _, err := tx.NewInsert().Model(&gtsmodel.StatusToEmoji{
+ StatusID: status.ID,
+ EmojiID: i,
+ }).Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ // create links between this status and any tags it uses
+ for _, i := range status.TagIDs {
+ if _, err := tx.NewInsert().Model(&gtsmodel.StatusToTag{
+ StatusID: status.ID,
+ TagID: i,
+ }).Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ // change the status ID of the media attachments to the new status
+ for _, a := range status.Attachments {
+ a.StatusID = status.ID
+ a.UpdatedAt = time.Now()
+ if _, err := s.conn.NewUpdate().Model(a).
+ Where("id = ?", a.ID).
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ _, err := tx.NewInsert().Model(status).Exec(ctx)
+ return err
+ }
+
+ return processErrorResponse(s.conn.RunInTx(ctx, nil, transaction))
+}
+
+func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
+ parents := []*gtsmodel.Status{}
+ s.statusParent(ctx, status, &parents, onlyDirect)
+
+ return parents, nil
+}
+
+func (s *statusDB) statusParent(ctx context.Context, status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
+ if status.InReplyToID == "" {
+ return
+ }
+
+ parentStatus, err := s.GetStatusByID(ctx, status.InReplyToID)
+ if err == nil {
+ *foundStatuses = append(*foundStatuses, parentStatus)
+ }
+
+ if onlyDirect {
+ return
+ }
+
+ s.statusParent(ctx, parentStatus, foundStatuses, false)
+}
+
+func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) {
+ foundStatuses := &list.List{}
+ foundStatuses.PushFront(status)
+ s.statusChildren(ctx, status, foundStatuses, onlyDirect, minID)
+
+ children := []*gtsmodel.Status{}
+ for e := foundStatuses.Front(); e != nil; e = e.Next() {
+ entry, ok := e.Value.(*gtsmodel.Status)
+ if !ok {
+ panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
+ }
+
+ // only append children, not the overall parent status
+ if entry.ID != status.ID {
+ children = append(children, entry)
+ }
+ }
+
+ return children, nil
+}
+
+func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
+ immediateChildren := []*gtsmodel.Status{}
+
+ q := s.conn.
+ NewSelect().
+ Model(&immediateChildren).
+ Where("in_reply_to_id = ?", status.ID)
+ if minID != "" {
+ q = q.Where("status.id > ?", minID)
+ }
+
+ if err := q.Scan(ctx); err != nil {
+ return
+ }
+
+ for _, child := range immediateChildren {
+ insertLoop:
+ for e := foundStatuses.Front(); e != nil; e = e.Next() {
+ entry, ok := e.Value.(*gtsmodel.Status)
+ if !ok {
+ panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
+ }
+
+ if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {
+ foundStatuses.InsertAfter(child, e)
+ break insertLoop
+ }
+ }
+
+ // only do one loop if we only want direct children
+ if onlyDirect {
+ return
+ }
+ s.statusChildren(ctx, child, foundStatuses, false, minID)
+ }
+}
+
+func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
+ return s.conn.NewSelect().Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count(ctx)
+}
+
+func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
+ return s.conn.NewSelect().Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count(ctx)
+}
+
+func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
+ return s.conn.NewSelect().Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count(ctx)
+}
+
+func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
+ q := s.conn.
+ NewSelect().
+ Model(&gtsmodel.StatusFave{}).
+ Where("status_id = ?", status.ID).
+ Where("account_id = ?", accountID)
+
+ return exists(ctx, q)
+}
+
+func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
+ q := s.conn.
+ NewSelect().
+ Model(&gtsmodel.Status{}).
+ Where("boost_of_id = ?", status.ID).
+ Where("account_id = ?", accountID)
+
+ return exists(ctx, q)
+}
+
+func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
+ q := s.conn.
+ NewSelect().
+ Model(&gtsmodel.StatusMute{}).
+ Where("status_id = ?", status.ID).
+ Where("account_id = ?", accountID)
+
+ return exists(ctx, q)
+}
+
+func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
+ q := s.conn.
+ NewSelect().
+ Model(&gtsmodel.StatusBookmark{}).
+ Where("status_id = ?", status.ID).
+ Where("account_id = ?", accountID)
+
+ return exists(ctx, q)
+}
+
+func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
+ faves := []*gtsmodel.StatusFave{}
+
+ q := s.newFaveQ(&faves).
+ Where("status_id = ?", status.ID)
+
+ err := processErrorResponse(q.Scan(ctx))
+ return faves, err
+}
+
+func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
+ reblogs := []*gtsmodel.Status{}
+
+ q := s.newStatusQ(&reblogs).
+ Where("boost_of_id = ?", status.ID)
+
+ err := processErrorResponse(q.Scan(ctx))
+ return reblogs, err
+}
diff --git a/internal/db/pg/status_test.go b/internal/db/bundb/status_test.go
index 8a185757c..513000577 100644
--- a/internal/db/pg/status_test.go
+++ b/internal/db/bundb/status_test.go
@@ -16,9 +16,10 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg_test
+package bundb_test
import (
+ "context"
"fmt"
"testing"
"time"
@@ -28,7 +29,7 @@ import (
)
type StatusTestSuite struct {
- PGStandardTestSuite
+ BunDBStandardTestSuite
}
func (suite *StatusTestSuite) SetupSuite() {
@@ -56,8 +57,9 @@ func (suite *StatusTestSuite) TearDownTest() {
}
func (suite *StatusTestSuite) TestGetStatusByID() {
- status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_1_status_1"].ID)
+ status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_1_status_1"].ID)
if err != nil {
+ fmt.Println(err.Error())
suite.FailNow(err.Error())
}
suite.NotNil(status)
@@ -70,7 +72,7 @@ func (suite *StatusTestSuite) TestGetStatusByID() {
}
func (suite *StatusTestSuite) TestGetStatusByURI() {
- status, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
+ status, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
if err != nil {
suite.FailNow(err.Error())
}
@@ -84,7 +86,7 @@ func (suite *StatusTestSuite) TestGetStatusByURI() {
}
func (suite *StatusTestSuite) TestGetStatusWithExtras() {
- status, err := suite.db.GetStatusByID(suite.testStatuses["admin_account_status_1"].ID)
+ status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["admin_account_status_1"].ID)
if err != nil {
suite.FailNow(err.Error())
}
@@ -97,7 +99,7 @@ func (suite *StatusTestSuite) TestGetStatusWithExtras() {
}
func (suite *StatusTestSuite) TestGetStatusWithMention() {
- status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_2_status_5"].ID)
+ status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_2_status_5"].ID)
if err != nil {
suite.FailNow(err.Error())
}
@@ -112,18 +114,18 @@ func (suite *StatusTestSuite) TestGetStatusWithMention() {
func (suite *StatusTestSuite) TestGetStatusTwice() {
before1 := time.Now()
- _, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
+ _, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
suite.NoError(err)
after1 := time.Now()
duration1 := after1.Sub(before1)
- fmt.Println(duration1.Nanoseconds())
+ fmt.Println(duration1.Milliseconds())
before2 := time.Now()
- _, err = suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
+ _, err = suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
suite.NoError(err)
after2 := time.Now()
duration2 := after2.Sub(before2)
- fmt.Println(duration2.Nanoseconds())
+ fmt.Println(duration2.Milliseconds())
// second retrieval should be several orders faster since it will be cached now
suite.Less(duration2, duration1)
diff --git a/internal/db/pg/timeline.go b/internal/db/bundb/timeline.go
index fa8b07aab..b62ad4c50 100644
--- a/internal/db/pg/timeline.go
+++ b/internal/db/bundb/timeline.go
@@ -16,43 +16,35 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package pg
+package bundb
import (
"context"
+ "database/sql"
"sort"
- "github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
)
type timelineDB struct {
config *config.Config
- conn *pg.DB
+ conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
-func (t *timelineDB) GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
+func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{}
- q := t.conn.Model(&statuses)
+ q := t.conn.
+ NewSelect().
+ Model(&statuses)
q = q.ColumnExpr("status.*").
// Find out who accountID follows.
Join("LEFT JOIN follows AS f ON f.target_account_id = status.account_id").
- // Use a WhereGroup here to specify that we want EITHER statuses posted by accounts that accountID follows,
- // OR statuses posted by accountID itself (since a user should be able to see their own statuses).
- //
- // This is equivalent to something like WHERE ... AND (... OR ...)
- // See: https://pg.uptrace.dev/queries/#select
- WhereGroup(func(q *pg.Query) (*pg.Query, error) {
- q = q.WhereOr("f.account_id = ?", accountID).
- WhereOr("status.account_id = ?", accountID)
- return q, nil
- }).
// Sort by highest ID (newest) to lowest ID (oldest)
Order("status.id DESC")
@@ -81,29 +73,32 @@ func (t *timelineDB) GetHomeTimeline(accountID string, maxID string, sinceID str
q = q.Limit(limit)
}
- err := q.Select()
- if err != nil {
- if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries
- }
- return nil, err
+ // Use a WhereGroup here to specify that we want EITHER statuses posted by accounts that accountID follows,
+ // OR statuses posted by accountID itself (since a user should be able to see their own statuses).
+ //
+ // This is equivalent to something like WHERE ... AND (... OR ...)
+ // See: https://bun.uptrace.dev/guide/queries.html#select
+ whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {
+ return q.
+ WhereOr("f.account_id = ?", accountID).
+ WhereOr("status.account_id = ?", accountID)
}
- if len(statuses) == 0 {
- return nil, db.ErrNoEntries
- }
+ q = q.WhereGroup(" AND ", whereGroup)
- return statuses, nil
+ return statuses, processErrorResponse(q.Scan(ctx))
}
-func (t *timelineDB) GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
+func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{}
- q := t.conn.Model(&statuses).
+ q := t.conn.
+ NewSelect().
+ Model(&statuses).
Where("visibility = ?", gtsmodel.VisibilityPublic).
- Where("? IS NULL", pg.Ident("in_reply_to_id")).
- Where("? IS NULL", pg.Ident("in_reply_to_uri")).
- Where("? IS NULL", pg.Ident("boost_of_id")).
+ Where("? IS NULL", bun.Ident("in_reply_to_id")).
+ Where("? IS NULL", bun.Ident("in_reply_to_uri")).
+ Where("? IS NULL", bun.Ident("boost_of_id")).
Order("status.id DESC")
if maxID != "" {
@@ -126,28 +121,18 @@ func (t *timelineDB) GetPublicTimeline(accountID string, maxID string, sinceID s
q = q.Limit(limit)
}
- err := q.Select()
- if err != nil {
- if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries
- }
- return nil, err
- }
-
- if len(statuses) == 0 {
- return nil, db.ErrNoEntries
- }
-
- return statuses, nil
+ return statuses, processErrorResponse(q.Scan(ctx))
}
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
-func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
+func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
faves := []*gtsmodel.StatusFave{}
- fq := t.conn.Model(&faves).
+ fq := t.conn.
+ NewSelect().
+ Model(&faves).
Where("account_id = ?", accountID).
Order("id DESC")
@@ -163,9 +148,9 @@ func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID stri
fq = fq.Limit(limit)
}
- err := fq.Select()
+ err := fq.Scan(ctx)
if err != nil {
- if err == pg.ErrNoRows {
+ if err == sql.ErrNoRows {
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
@@ -185,9 +170,13 @@ func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID stri
}
statuses := []*gtsmodel.Status{}
- err = t.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select()
+ err = t.conn.
+ NewSelect().
+ Model(&statuses).
+ Where("id IN (?)", bun.In(in)).
+ Scan(ctx)
if err != nil {
- if err == pg.ErrNoRows {
+ if err == sql.ErrNoRows {
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go
new file mode 100644
index 000000000..115d18de2
--- /dev/null
+++ b/internal/db/bundb/util.go
@@ -0,0 +1,78 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "context"
+ "strings"
+
+ "database/sql"
+
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/uptrace/bun"
+)
+
+// processErrorResponse parses the given error and returns an appropriate DBError.
+func processErrorResponse(err error) db.Error {
+ switch err {
+ case nil:
+ return nil
+ case sql.ErrNoRows:
+ return db.ErrNoEntries
+ default:
+ if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
+ return db.ErrAlreadyExists
+ }
+ return err
+ }
+}
+
+func exists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
+ count, err := q.Count(ctx)
+
+ exists := count != 0
+
+ err = processErrorResponse(err)
+
+ if err != nil {
+ if err == db.ErrNoEntries {
+ return false, nil
+ }
+ return false, err
+ }
+
+ return exists, nil
+}
+
+func notExists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
+ count, err := q.Count(ctx)
+
+ notExists := count == 0
+
+ err = processErrorResponse(err)
+
+ if err != nil {
+ if err == db.ErrNoEntries {
+ return true, nil
+ }
+ return false, err
+ }
+
+ return notExists, nil
+}
diff --git a/internal/db/db.go b/internal/db/db.go
index d6ac883e4..ec94fcfe7 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -19,6 +19,8 @@
package db
import (
+ "context"
+
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
@@ -38,6 +40,7 @@ type DB interface {
Mention
Notification
Relationship
+ Session
Status
Timeline
@@ -52,7 +55,7 @@ type DB interface {
//
// Note: this func doesn't/shouldn't do any manipulation of the accounts in the DB, it's just for checking
// if they exist in the db and conveniently returning them if they do.
- MentionStringsToMentions(targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error)
+ MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error)
// TagStringsToTags takes a slice of deduplicated, lowercase tags in the form "somehashtag", which have been
// used in a status. It takes the id of the account that wrote the status, and the id of the status itself, and then
@@ -61,7 +64,7 @@ type DB interface {
//
// Note: this func doesn't/shouldn't do any manipulation of the tags in the DB, it's just for checking
// if they exist in the db already, and conveniently returning them, or creating new tag structs.
- TagStringsToTags(tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error)
+ TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error)
// EmojiStringsToEmojis takes a slice of deduplicated, lowercase emojis in the form ":emojiname:", which have been
// used in a status. It takes the id of the account that wrote the status, and the id of the status itself, and then
@@ -69,5 +72,5 @@ type DB interface {
//
// Note: this func doesn't/shouldn't do any manipulation of the emoji in the DB, it's just for checking
// if they exist in the db and conveniently returning them if they do.
- EmojiStringsToEmojis(emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error)
+ EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error)
}
diff --git a/internal/db/domain.go b/internal/db/domain.go
index a6583c80c..df50a6770 100644
--- a/internal/db/domain.go
+++ b/internal/db/domain.go
@@ -18,19 +18,22 @@
package db
-import "net/url"
+import (
+ "context"
+ "net/url"
+)
// Domain contains DB functions related to domains and domain blocks.
type Domain interface {
// IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`).
- IsDomainBlocked(domain string) (bool, Error)
+ IsDomainBlocked(ctx context.Context, domain string) (bool, Error)
// AreDomainsBlocked checks if an instance-level domain block exists for any of the given domains strings, and returns true if even one is found.
- AreDomainsBlocked(domains []string) (bool, Error)
+ AreDomainsBlocked(ctx context.Context, domains []string) (bool, Error)
// IsURIBlocked checks if an instance-level domain block exists for the `host` in the given URI (eg., `https://example.org/users/whatever`).
- IsURIBlocked(uri *url.URL) (bool, Error)
+ IsURIBlocked(ctx context.Context, uri *url.URL) (bool, Error)
// AreURIsBlocked checks if an instance-level domain block exists for any `host` in the given URI slice, and returns true if even one is found.
- AreURIsBlocked(uris []*url.URL) (bool, Error)
+ AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, Error)
}
diff --git a/internal/db/instance.go b/internal/db/instance.go
index 1f7c83e4f..dcd978a81 100644
--- a/internal/db/instance.go
+++ b/internal/db/instance.go
@@ -18,19 +18,23 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Instance contains functions for instance-level actions (counting instance users etc.).
type Instance interface {
// CountInstanceUsers returns the number of known accounts registered with the given domain.
- CountInstanceUsers(domain string) (int, Error)
+ CountInstanceUsers(ctx context.Context, domain string) (int, Error)
// CountInstanceStatuses returns the number of known statuses posted from the given domain.
- CountInstanceStatuses(domain string) (int, Error)
+ CountInstanceStatuses(ctx context.Context, domain string) (int, Error)
// CountInstanceDomains returns the number of known instances known that the given domain federates with.
- CountInstanceDomains(domain string) (int, Error)
+ CountInstanceDomains(ctx context.Context, domain string) (int, Error)
// GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID.
- GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, Error)
+ GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, Error)
}
diff --git a/internal/db/media.go b/internal/db/media.go
index db4db3411..b779dd276 100644
--- a/internal/db/media.go
+++ b/internal/db/media.go
@@ -18,10 +18,14 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Media contains functions related to creating/getting/removing media attachments.
type Media interface {
// GetAttachmentByID gets a single attachment by its ID
- GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, Error)
+ GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, Error)
}
diff --git a/internal/db/mention.go b/internal/db/mention.go
index cb1c56dc1..b9b45546a 100644
--- a/internal/db/mention.go
+++ b/internal/db/mention.go
@@ -18,13 +18,17 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Mention contains functions for getting/creating mentions in the database.
type Mention interface {
// GetMention gets a single mention by ID
- GetMention(id string) (*gtsmodel.Mention, Error)
+ GetMention(ctx context.Context, id string) (*gtsmodel.Mention, Error)
// GetMentions gets multiple mentions.
- GetMentions(ids []string) ([]*gtsmodel.Mention, Error)
+ GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, Error)
}
diff --git a/internal/db/notification.go b/internal/db/notification.go
index 326f0f149..09c17f031 100644
--- a/internal/db/notification.go
+++ b/internal/db/notification.go
@@ -18,14 +18,18 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Notification contains functions for creating and getting notifications.
type Notification interface {
// GetNotifications returns a slice of notifications that pertain to the given accountID.
//
// Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest).
- GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
+ GetNotifications(ctx context.Context, accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
// GetNotification returns one notification according to its id.
- GetNotification(id string) (*gtsmodel.Notification, Error)
+ GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, Error)
}
diff --git a/internal/db/pg/basic.go b/internal/db/pg/basic.go
deleted file mode 100644
index 6e76b4450..000000000
--- a/internal/db/pg/basic.go
+++ /dev/null
@@ -1,205 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see <http://www.gnu.org/licenses/>.
-*/
-
-package pg
-
-import (
- "context"
- "errors"
- "fmt"
- "strings"
-
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
- "github.com/sirupsen/logrus"
- "github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
-)
-
-type basicDB struct {
- config *config.Config
- conn *pg.DB
- log *logrus.Logger
- cancel context.CancelFunc
-}
-
-func (b *basicDB) Put(i interface{}) db.Error {
- _, err := b.conn.Model(i).Insert(i)
- if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
- return db.ErrAlreadyExists
- }
- return err
-}
-
-func (b *basicDB) GetByID(id string, i interface{}) db.Error {
- if err := b.conn.Model(i).Where("id = ?", id).Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries
- }
- return err
-
- }
- return nil
-}
-
-func (b *basicDB) GetWhere(where []db.Where, i interface{}) db.Error {
- if len(where) == 0 {
- return errors.New("no queries provided")
- }
-
- q := b.conn.Model(i)
- for _, w := range where {
-
- if w.Value == nil {
- q = q.Where("? IS NULL", pg.Ident(w.Key))
- } else {
- if w.CaseInsensitive {
- q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
- } else {
- q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
- }
- }
- }
-
- if err := q.Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries
- }
- return err
- }
- return nil
-}
-
-func (b *basicDB) GetAll(i interface{}) db.Error {
- if err := b.conn.Model(i).Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries
- }
- return err
- }
- return nil
-}
-
-func (b *basicDB) DeleteByID(id string, i interface{}) db.Error {
- if _, err := b.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
- // if there are no rows *anyway* then that's fine
- // just return err if there's an actual error
- if err != pg.ErrNoRows {
- return err
- }
- }
- return nil
-}
-
-func (b *basicDB) DeleteWhere(where []db.Where, i interface{}) db.Error {
- if len(where) == 0 {
- return errors.New("no queries provided")
- }
-
- q := b.conn.Model(i)
- for _, w := range where {
- q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
- }
-
- if _, err := q.Delete(); err != nil {
- // if there are no rows *anyway* then that's fine
- // just return err if there's an actual error
- if err != pg.ErrNoRows {
- return err
- }
- }
- return nil
-}
-
-func (b *basicDB) Upsert(i interface{}, conflictColumn string) db.Error {
- if _, err := b.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries
- }
- return err
- }
- return nil
-}
-
-func (b *basicDB) UpdateByID(id string, i interface{}) db.Error {
- if _, err := b.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries
- }
- return err
- }
- return nil
-}
-
-func (b *basicDB) UpdateOneByID(id string, key string, value interface{}, i interface{}) db.Error {
- _, err := b.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
- return err
-}
-
-func (b *basicDB) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) db.Error {
- q := b.conn.Model(i)
-
- for _, w := range where {
- if w.Value == nil {
- q = q.Where("? IS NULL", pg.Ident(w.Key))
- } else {
- if w.CaseInsensitive {
- q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
- } else {
- q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
- }
- }
- }
-
- q = q.Set("? = ?", pg.Safe(key), value)
-
- _, err := q.Update()
-
- return err
-}
-
-func (b *basicDB) CreateTable(i interface{}) db.Error {
- return b.conn.Model(i).CreateTable(&orm.CreateTableOptions{
- IfNotExists: true,
- })
-}
-
-func (b *basicDB) DropTable(i interface{}) db.Error {
- return b.conn.Model(i).DropTable(&orm.DropTableOptions{
- IfExists: true,
- })
-}
-
-func (b *basicDB) RegisterTable(i interface{}) db.Error {
- orm.RegisterTable(i)
- return nil
-}
-
-func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
- return b.conn.Ping(ctx)
-}
-
-func (b *basicDB) Stop(ctx context.Context) db.Error {
- b.log.Info("closing db connection")
- if err := b.conn.Close(); err != nil {
- // only cancel if there's a problem closing the db
- b.cancel()
- return err
- }
- return nil
-}
diff --git a/internal/db/pg/status.go b/internal/db/pg/status.go
deleted file mode 100644
index 99790428e..000000000
--- a/internal/db/pg/status.go
+++ /dev/null
@@ -1,318 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see <http://www.gnu.org/licenses/>.
-*/
-
-package pg
-
-import (
- "container/list"
- "context"
- "errors"
- "time"
-
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
- "github.com/sirupsen/logrus"
- "github.com/superseriousbusiness/gotosocial/internal/cache"
- "github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
-)
-
-type statusDB struct {
- config *config.Config
- conn *pg.DB
- log *logrus.Logger
- cancel context.CancelFunc
- cache cache.Cache
-}
-
-func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) {
- if s.cache == nil {
- s.cache = cache.New()
- }
-
- if err := s.cache.Store(id, status); err != nil {
- s.log.Panicf("statusDB: error storing in cache: %s", err)
- }
-}
-
-func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) {
- if s.cache == nil {
- s.cache = cache.New()
- return nil, false
- }
-
- sI, err := s.cache.Fetch(id)
- if err != nil || sI == nil {
- return nil, false
- }
-
- status, ok := sI.(*gtsmodel.Status)
- if !ok {
- s.log.Panicf("statusDB: cached interface with key %s was not a status", id)
- }
-
- return status, true
-}
-
-func (s *statusDB) newStatusQ(status interface{}) *orm.Query {
- return s.conn.Model(status).
- Relation("Attachments").
- Relation("Tags").
- Relation("Mentions").
- Relation("Emojis").
- Relation("Account").
- Relation("InReplyTo").
- Relation("InReplyToAccount").
- Relation("BoostOf").
- Relation("BoostOfAccount").
- Relation("CreatedWithApplication")
-}
-
-func (s *statusDB) newFaveQ(faves interface{}) *orm.Query {
- return s.conn.Model(faves).
- Relation("Account").
- Relation("TargetAccount").
- Relation("Status")
-}
-
-func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.Error) {
- if status, cached := s.statusCached(id); cached {
- return status, nil
- }
-
- status := &gtsmodel.Status{}
-
- q := s.newStatusQ(status).
- Where("status.id = ?", id)
-
- err := processErrorResponse(q.Select())
-
- if err == nil && status != nil {
- s.cacheStatus(id, status)
- }
-
- return status, err
-}
-
-func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.Error) {
- if status, cached := s.statusCached(uri); cached {
- return status, nil
- }
-
- status := &gtsmodel.Status{}
-
- q := s.newStatusQ(status).
- Where("LOWER(status.uri) = LOWER(?)", uri)
-
- err := processErrorResponse(q.Select())
-
- if err == nil && status != nil {
- s.cacheStatus(uri, status)
- }
-
- return status, err
-}
-
-func (s *statusDB) GetStatusByURL(uri string) (*gtsmodel.Status, db.Error) {
- if status, cached := s.statusCached(uri); cached {
- return status, nil
- }
-
- status := &gtsmodel.Status{}
-
- q := s.newStatusQ(status).
- Where("LOWER(status.url) = LOWER(?)", uri)
-
- err := processErrorResponse(q.Select())
-
- if err == nil && status != nil {
- s.cacheStatus(uri, status)
- }
-
- return status, err
-}
-
-func (s *statusDB) PutStatus(status *gtsmodel.Status) db.Error {
- transaction := func(tx *pg.Tx) error {
- // create links between this status and any emojis it uses
- for _, i := range status.EmojiIDs {
- if _, err := tx.Model(&gtsmodel.StatusToEmoji{
- StatusID: status.ID,
- EmojiID: i,
- }).Insert(); err != nil {
- return err
- }
- }
-
- // create links between this status and any tags it uses
- for _, i := range status.TagIDs {
- if _, err := tx.Model(&gtsmodel.StatusToTag{
- StatusID: status.ID,
- TagID: i,
- }).Insert(); err != nil {
- return err
- }
- }
-
- // change the status ID of the media attachments to the new status
- for _, a := range status.Attachments {
- a.StatusID = status.ID
- a.UpdatedAt = time.Now()
- if _, err := s.conn.Model(a).
- Where("id = ?", a.ID).
- Update(); err != nil {
- return err
- }
- }
-
- _, err := tx.Model(status).Insert()
- return err
- }
-
- return processErrorResponse(s.conn.RunInTransaction(context.Background(), transaction))
-}
-
-func (s *statusDB) GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
- parents := []*gtsmodel.Status{}
- s.statusParent(status, &parents, onlyDirect)
-
- return parents, nil
-}
-
-func (s *statusDB) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
- if status.InReplyToID == "" {
- return
- }
-
- parentStatus, err := s.GetStatusByID(status.InReplyToID)
- if err == nil {
- *foundStatuses = append(*foundStatuses, parentStatus)
- }
-
- if onlyDirect {
- return
- }
-
- s.statusParent(parentStatus, foundStatuses, false)
-}
-
-func (s *statusDB) GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) {
- foundStatuses := &list.List{}
- foundStatuses.PushFront(status)
- s.statusChildren(status, foundStatuses, onlyDirect, minID)
-
- children := []*gtsmodel.Status{}
- for e := foundStatuses.Front(); e != nil; e = e.Next() {
- entry, ok := e.Value.(*gtsmodel.Status)
- if !ok {
- panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
- }
-
- // only append children, not the overall parent status
- if entry.ID != status.ID {
- children = append(children, entry)
- }
- }
-
- return children, nil
-}
-
-func (s *statusDB) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
- immediateChildren := []*gtsmodel.Status{}
-
- q := s.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID)
- if minID != "" {
- q = q.Where("status.id > ?", minID)
- }
-
- if err := q.Select(); err != nil {
- return
- }
-
- for _, child := range immediateChildren {
- insertLoop:
- for e := foundStatuses.Front(); e != nil; e = e.Next() {
- entry, ok := e.Value.(*gtsmodel.Status)
- if !ok {
- panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
- }
-
- if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {
- foundStatuses.InsertAfter(child, e)
- break insertLoop
- }
- }
-
- // only do one loop if we only want direct children
- if onlyDirect {
- return
- }
- s.statusChildren(child, foundStatuses, false, minID)
- }
-}
-
-func (s *statusDB) CountStatusReplies(status *gtsmodel.Status) (int, db.Error) {
- return s.conn.Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count()
-}
-
-func (s *statusDB) CountStatusReblogs(status *gtsmodel.Status) (int, db.Error) {
- return s.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count()
-}
-
-func (s *statusDB) CountStatusFaves(status *gtsmodel.Status) (int, db.Error) {
- return s.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count()
-}
-
-func (s *statusDB) IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
- return s.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
-}
-
-func (s *statusDB) IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
- return s.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
-}
-
-func (s *statusDB) IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
- return s.conn.Model(&gtsmodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
-}
-
-func (s *statusDB) IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
- return s.conn.Model(&gtsmodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
-}
-
-func (s *statusDB) GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
- faves := []*gtsmodel.StatusFave{}
-
- q := s.newFaveQ(&faves).
- Where("status_id = ?", status.ID)
-
- err := processErrorResponse(q.Select())
-
- return faves, err
-}
-
-func (s *statusDB) GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
- reblogs := []*gtsmodel.Status{}
-
- q := s.newStatusQ(&reblogs).
- Where("boost_of_id = ?", status.ID)
-
- err := processErrorResponse(q.Select())
-
- return reblogs, err
-}
diff --git a/internal/db/pg/util.go b/internal/db/pg/util.go
deleted file mode 100644
index 17c09b720..000000000
--- a/internal/db/pg/util.go
+++ /dev/null
@@ -1,25 +0,0 @@
-package pg
-
-import (
- "strings"
-
- "github.com/go-pg/pg/v10"
- "github.com/superseriousbusiness/gotosocial/internal/db"
-)
-
-// processErrorResponse parses the given error and returns an appropriate DBError.
-func processErrorResponse(err error) db.Error {
- switch err {
- case nil:
- return nil
- case pg.ErrNoRows:
- return db.ErrNoEntries
- case pg.ErrMultiRows:
- return db.ErrMultipleEntries
- default:
- if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
- return db.ErrAlreadyExists
- }
- return err
- }
-}
diff --git a/internal/db/relationship.go b/internal/db/relationship.go
index 85f64d72b..804526425 100644
--- a/internal/db/relationship.go
+++ b/internal/db/relationship.go
@@ -18,54 +18,58 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Relationship contains functions for getting or modifying the relationship between two accounts.
type Relationship interface {
// IsBlocked checks whether account 1 has a block in place against block2.
// If eitherDirection is true, then the function returns true if account1 blocks account2, OR if account2 blocks account1.
- IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, Error)
+ IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, Error)
// GetBlock returns the block from account1 targeting account2, if it exists, or an error if it doesn't.
//
// Because this is slower than Blocked, only use it if you need the actual Block struct for some reason,
// not if you're just checking for the existence of a block.
- GetBlock(account1 string, account2 string) (*gtsmodel.Block, Error)
+ GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, Error)
// GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
- GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error)
+ GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error)
// IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
- IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
+ IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
// IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
- IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
+ IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
// IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
- IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error)
+ IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error)
// AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table.
// In other words, it should create the follow, and delete the existing follow request.
//
// It will return the newly created follow for further processing.
- AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error)
+ AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error)
// GetAccountFollowRequests returns all follow requests targeting the given account.
- GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, Error)
+ GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, Error)
// GetAccountFollows returns a slice of follows owned by the given accountID.
- GetAccountFollows(accountID string) ([]*gtsmodel.Follow, Error)
+ GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, Error)
// CountAccountFollows returns the amount of accounts that the given accountID is following.
//
// If localOnly is set to true, then only follows from *this instance* will be returned.
- CountAccountFollows(accountID string, localOnly bool) (int, Error)
+ CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, Error)
// GetAccountFollowedBy fetches follows that target given accountID.
//
// If localOnly is set to true, then only follows from *this instance* will be returned.
- GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, Error)
+ GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, Error)
// CountAccountFollowedBy returns the amounts that the given ID is followed by.
- CountAccountFollowedBy(accountID string, localOnly bool) (int, Error)
+ CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, Error)
}
diff --git a/internal/db/session.go b/internal/db/session.go
new file mode 100644
index 000000000..ae13dccce
--- /dev/null
+++ b/internal/db/session.go
@@ -0,0 +1,31 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package db
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+// Session handles getting/creation of router sessions.
+type Session interface {
+ GetSession(ctx context.Context) (*gtsmodel.RouterSession, Error)
+ CreateSession(ctx context.Context) (*gtsmodel.RouterSession, Error)
+}
diff --git a/internal/db/status.go b/internal/db/status.go
index 9d206c198..7430433c4 100644
--- a/internal/db/status.go
+++ b/internal/db/status.go
@@ -18,58 +18,62 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses.
type Status interface {
// GetStatusByID returns one status from the database, with all rel fields populated (if possible).
- GetStatusByID(id string) (*gtsmodel.Status, Error)
+ GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, Error)
// GetStatusByURI returns one status from the database, with all rel fields populated (if possible).
- GetStatusByURI(uri string) (*gtsmodel.Status, Error)
+ GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, Error)
// GetStatusByURL returns one status from the database, with all rel fields populated (if possible).
- GetStatusByURL(uri string) (*gtsmodel.Status, Error)
+ GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error)
// PutStatus stores one status in the database.
- PutStatus(status *gtsmodel.Status) Error
+ PutStatus(ctx context.Context, status *gtsmodel.Status) Error
// CountStatusReplies returns the amount of replies recorded for a status, or an error if something goes wrong
- CountStatusReplies(status *gtsmodel.Status) (int, Error)
+ CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, Error)
// CountStatusReblogs returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong
- CountStatusReblogs(status *gtsmodel.Status) (int, Error)
+ CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, Error)
// CountStatusFaves returns the amount of faves/likes recorded for a status, or an error if something goes wrong
- CountStatusFaves(status *gtsmodel.Status) (int, Error)
+ CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, Error)
// GetStatusParents gets the parent statuses of a given status.
//
// If onlyDirect is true, only the immediate parent will be returned.
- GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error)
+ GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error)
// GetStatusChildren gets the child statuses of a given status.
//
// If onlyDirect is true, only the immediate children will be returned.
- GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error)
+ GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error)
// IsStatusFavedBy checks if a given status has been faved by a given account ID
- IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, Error)
+ IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID
- IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, Error)
+ IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusMutedBy checks if a given status has been muted by a given account ID
- IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, Error)
+ IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID
- IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, Error)
+ IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
// GetStatusFaves returns a slice of faves/likes of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
- GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, Error)
+ GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, Error)
// GetStatusReblogs returns a slice of statuses that are a boost/reblog of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
- GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, Error)
+ GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, Error)
}
diff --git a/internal/db/timeline.go b/internal/db/timeline.go
index 74aa5c781..83fb3a959 100644
--- a/internal/db/timeline.go
+++ b/internal/db/timeline.go
@@ -18,20 +18,24 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Timeline contains functionality for retrieving home/public/faved etc timelines for an account.
type Timeline interface {
// GetHomeTimeline returns a slice of statuses from accounts that are followed by the given account id.
//
// Statuses should be returned in descending order of when they were created (newest first).
- GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
+ GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
// GetPublicTimeline fetches the account's PUBLIC timeline -- ie., posts and replies that are public.
// It will use the given filters and try to return as many statuses as possible up to the limit.
//
// Statuses should be returned in descending order of when they were created (newest first).
- GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
+ GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
// GetFavedTimeline fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved.
// It will use the given filters and try to return as many statuses as possible up to the limit.
@@ -40,5 +44,5 @@ type Timeline interface {
// In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created.
//
// Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers.
- GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error)
+ GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error)
}