summaryrefslogtreecommitdiff
path: root/internal/db/bundb
diff options
context:
space:
mode:
authorLibravatar tobi <31960611+tsmethurst@users.noreply.github.com>2021-08-25 15:34:33 +0200
committerLibravatar GitHub <noreply@github.com>2021-08-25 15:34:33 +0200
commit2dc9fc1626507bb54417fc4a1920b847cafb27a2 (patch)
tree4ddeac479b923db38090aac8bd9209f3646851c1 /internal/db/bundb
parentManually approves followers (#146) (diff)
downloadgotosocial-2dc9fc1626507bb54417fc4a1920b847cafb27a2.tar.xz
Pg to bun (#148)
* start moving to bun * changing more stuff * more * and yet more * tests passing * seems stable now * more big changes * small fix * little fixes
Diffstat (limited to 'internal/db/bundb')
-rw-r--r--internal/db/bundb/account.go291
-rw-r--r--internal/db/bundb/account_test.go86
-rw-r--r--internal/db/bundb/admin.go272
-rw-r--r--internal/db/bundb/basic.go179
-rw-r--r--internal/db/bundb/basic_test.go68
-rw-r--r--internal/db/bundb/bundb.go410
-rw-r--r--internal/db/bundb/bundb_test.go47
-rw-r--r--internal/db/bundb/domain.go81
-rw-r--r--internal/db/bundb/instance.go118
-rw-r--r--internal/db/bundb/media.go53
-rw-r--r--internal/db/bundb/mention.go108
-rw-r--r--internal/db/bundb/notification.go136
-rw-r--r--internal/db/bundb/relationship.go328
-rw-r--r--internal/db/bundb/session.go85
-rw-r--r--internal/db/bundb/status.go375
-rw-r--r--internal/db/bundb/status_test.go136
-rw-r--r--internal/db/bundb/timeline.go199
-rw-r--r--internal/db/bundb/util.go78
18 files changed, 3050 insertions, 0 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
new file mode 100644
index 000000000..7ebb79a15
--- /dev/null
+++ b/internal/db/bundb/account.go
@@ -0,0 +1,291 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+type accountDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
+ return a.conn.
+ NewSelect().
+ Model(account).
+ Relation("AvatarMediaAttachment").
+ Relation("HeaderMediaAttachment")
+}
+
+func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
+ account := new(gtsmodel.Account)
+
+ q := a.newAccountQ(account).
+ Where("account.id = ?", id)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return account, err
+}
+
+func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+ account := new(gtsmodel.Account)
+
+ q := a.newAccountQ(account).
+ Where("account.uri = ?", uri)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return account, err
+}
+
+func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+ account := new(gtsmodel.Account)
+
+ q := a.newAccountQ(account).
+ Where("account.url = ?", uri)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return account, err
+}
+
+func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
+ if strings.TrimSpace(account.ID) == "" {
+ return nil, errors.New("account had no ID")
+ }
+
+ account.UpdatedAt = time.Now()
+
+ q := a.conn.
+ NewUpdate().
+ Model(account).
+ WherePK()
+
+ _, err := q.Exec(ctx)
+
+ err = processErrorResponse(err)
+
+ return account, err
+}
+
+func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
+ account := new(gtsmodel.Account)
+
+ q := a.newAccountQ(account)
+
+ if domain == "" {
+ q = q.
+ Where("account.username = ?", domain).
+ Where("account.domain = ?", domain)
+ } else {
+ q = q.
+ Where("account.username = ?", domain).
+ Where("? IS NULL", bun.Ident("domain"))
+ }
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return account, err
+}
+
+func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, db.Error) {
+ status := new(gtsmodel.Status)
+
+ q := a.conn.
+ NewSelect().
+ Model(status).
+ Order("id DESC").
+ Limit(1).
+ Where("account_id = ?", accountID).
+ Column("created_at")
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return status.CreatedAt, err
+}
+
+func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
+ if mediaAttachment.Avatar && mediaAttachment.Header {
+ return errors.New("one media attachment cannot be both header and avatar")
+ }
+
+ var headerOrAVI string
+ if mediaAttachment.Avatar {
+ headerOrAVI = "avatar"
+ } else if mediaAttachment.Header {
+ headerOrAVI = "header"
+ } else {
+ return errors.New("given media attachment was neither a header nor an avatar")
+ }
+
+ // TODO: there are probably more side effects here that need to be handled
+ if _, err := a.conn.
+ NewInsert().
+ Model(mediaAttachment).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ if _, err := a.conn.
+ NewUpdate().
+ Model(&gtsmodel.Account{}).
+ Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).
+ Where("id = ?", accountID).
+ Exec(ctx); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, db.Error) {
+ account := new(gtsmodel.Account)
+
+ q := a.newAccountQ(account).
+ Where("username = ?", username).
+ Where("? IS NULL", bun.Ident("domain"))
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return account, err
+}
+
+func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) {
+ faves := new([]*gtsmodel.StatusFave)
+
+ if err := a.conn.
+ NewSelect().
+ Model(faves).
+ Where("account_id = ?", accountID).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+ return *faves, nil
+}
+
+func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) {
+ return a.conn.
+ NewSelect().
+ Model(&gtsmodel.Status{}).
+ Where("account_id = ?", accountID).
+ Count(ctx)
+}
+
+func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) {
+ a.log.Debugf("getting statuses for account %s", accountID)
+ statuses := []*gtsmodel.Status{}
+
+ q := a.conn.
+ NewSelect().
+ Model(&statuses).
+ Order("id DESC")
+
+ if accountID != "" {
+ q = q.Where("account_id = ?", accountID)
+ }
+
+ if limit != 0 {
+ q = q.Limit(limit)
+ }
+
+ if excludeReplies {
+ q = q.Where("? IS NULL", bun.Ident("in_reply_to_id"))
+ }
+
+ if pinnedOnly {
+ q = q.Where("pinned = ?", true)
+ }
+
+ if maxID != "" {
+ q = q.Where("id < ?", maxID)
+ }
+
+ if mediaOnly {
+ q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery {
+ return q.
+ WhereOr("? IS NOT NULL", bun.Ident("attachments")).
+ WhereOr("attachments != '{}'")
+ })
+ }
+
+ if err := q.Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ if len(statuses) == 0 {
+ return nil, db.ErrNoEntries
+ }
+
+ a.log.Debugf("returning statuses for account %s", accountID)
+ return statuses, nil
+}
+
+func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) {
+ blocks := []*gtsmodel.Block{}
+
+ fq := a.conn.
+ NewSelect().
+ Model(&blocks).
+ Where("block.account_id = ?", accountID).
+ Relation("TargetAccount").
+ Order("block.id DESC")
+
+ if maxID != "" {
+ fq = fq.Where("block.id < ?", maxID)
+ }
+
+ if sinceID != "" {
+ fq = fq.Where("block.id > ?", sinceID)
+ }
+
+ if limit > 0 {
+ fq = fq.Limit(limit)
+ }
+
+ err := fq.Scan(ctx)
+ if err != nil {
+ return nil, "", "", err
+ }
+
+ if len(blocks) == 0 {
+ return nil, "", "", db.ErrNoEntries
+ }
+
+ accounts := []*gtsmodel.Account{}
+ for _, b := range blocks {
+ accounts = append(accounts, b.TargetAccount)
+ }
+
+ nextMaxID := blocks[len(blocks)-1].ID
+ prevMinID := blocks[0].ID
+ return accounts, nextMaxID, prevMinID, nil
+}
diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go
new file mode 100644
index 000000000..7174b781d
--- /dev/null
+++ b/internal/db/bundb/account_test.go
@@ -0,0 +1,86 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb_test
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type AccountTestSuite struct {
+ BunDBStandardTestSuite
+}
+
+func (suite *AccountTestSuite) SetupSuite() {
+ suite.testTokens = testrig.NewTestTokens()
+ suite.testClients = testrig.NewTestClients()
+ suite.testApplications = testrig.NewTestApplications()
+ suite.testUsers = testrig.NewTestUsers()
+ suite.testAccounts = testrig.NewTestAccounts()
+ suite.testAttachments = testrig.NewTestAttachments()
+ suite.testStatuses = testrig.NewTestStatuses()
+ suite.testTags = testrig.NewTestTags()
+ suite.testMentions = testrig.NewTestMentions()
+}
+
+func (suite *AccountTestSuite) SetupTest() {
+ suite.config = testrig.NewTestConfig()
+ suite.db = testrig.NewTestDB()
+ suite.log = testrig.NewTestLog()
+
+ testrig.StandardDBSetup(suite.db, suite.testAccounts)
+}
+
+func (suite *AccountTestSuite) TearDownTest() {
+ testrig.StandardDBTeardown(suite.db)
+}
+
+func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
+ account, err := suite.db.GetAccountByID(context.Background(), suite.testAccounts["local_account_1"].ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.NotNil(account)
+ suite.NotNil(account.AvatarMediaAttachment)
+ suite.NotEmpty(account.AvatarMediaAttachment.URL)
+ suite.NotNil(account.HeaderMediaAttachment)
+ suite.NotEmpty(account.HeaderMediaAttachment.URL)
+}
+
+func (suite *AccountTestSuite) TestUpdateAccount() {
+ testAccount := suite.testAccounts["local_account_1"]
+
+ testAccount.DisplayName = "new display name!"
+
+ _, err := suite.db.UpdateAccount(context.Background(), testAccount)
+ suite.NoError(err)
+
+ updated, err := suite.db.GetAccountByID(context.Background(), testAccount.ID)
+ suite.NoError(err)
+ suite.Equal("new display name!", updated.DisplayName)
+ suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second)
+}
+
+func TestAccountTestSuite(t *testing.T) {
+ suite.Run(t, new(AccountTestSuite))
+}
diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go
new file mode 100644
index 000000000..67a1e8a0d
--- /dev/null
+++ b/internal/db/bundb/admin.go
@@ -0,0 +1,272 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "database/sql"
+ "fmt"
+ "net"
+ "net/mail"
+ "strings"
+ "time"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/id"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
+ "github.com/uptrace/bun"
+ "golang.org/x/crypto/bcrypt"
+)
+
+type adminDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
+ q := a.conn.
+ NewSelect().
+ Model(&gtsmodel.Account{}).
+ Where("username = ?", username).
+ Where("domain = ?", nil)
+
+ return notExists(ctx, q)
+}
+
+func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.Error) {
+ // parse the domain from the email
+ m, err := mail.ParseAddress(email)
+ if err != nil {
+ return false, fmt.Errorf("error parsing email address %s: %s", email, err)
+ }
+ domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
+
+ // check if the email domain is blocked
+ if err := a.conn.
+ NewSelect().
+ Model(&gtsmodel.EmailDomainBlock{}).
+ Where("domain = ?", domain).
+ Scan(ctx); err == nil {
+ // fail because we found something
+ return false, fmt.Errorf("email domain %s is blocked", domain)
+ } else if err != sql.ErrNoRows {
+ return false, processErrorResponse(err)
+ }
+
+ // check if this email is associated with a user already
+ q := a.conn.
+ NewSelect().
+ Model(&gtsmodel.User{}).
+ Where("email = ?", email).
+ WhereOr("unconfirmed_email = ?", email)
+
+ return notExists(ctx, q)
+}
+
+func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ a.log.Errorf("error creating new rsa key: %s", err)
+ return nil, err
+ }
+
+ // if something went wrong while creating a user, we might already have an account, so check here first...
+ acct := &gtsmodel.Account{}
+ err = a.conn.NewSelect().
+ Model(acct).
+ Where("username = ?", username).
+ Where("? IS NULL", bun.Ident("domain")).
+ Scan(ctx)
+ if err != nil {
+ // we just don't have an account yet create one
+ newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
+ newAccountID, err := id.NewRandomULID()
+ if err != nil {
+ return nil, err
+ }
+
+ acct = &gtsmodel.Account{
+ ID: newAccountID,
+ Username: username,
+ DisplayName: username,
+ Reason: reason,
+ URL: newAccountURIs.UserURL,
+ PrivateKey: key,
+ PublicKey: &key.PublicKey,
+ PublicKeyURI: newAccountURIs.PublicKeyURI,
+ ActorType: gtsmodel.ActivityStreamsPerson,
+ URI: newAccountURIs.UserURI,
+ InboxURI: newAccountURIs.InboxURI,
+ OutboxURI: newAccountURIs.OutboxURI,
+ FollowersURI: newAccountURIs.FollowersURI,
+ FollowingURI: newAccountURIs.FollowingURI,
+ FeaturedCollectionURI: newAccountURIs.CollectionURI,
+ }
+ if _, err = a.conn.
+ NewInsert().
+ Model(acct).
+ Exec(ctx); err != nil {
+ return nil, err
+ }
+ }
+
+ pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
+ if err != nil {
+ return nil, fmt.Errorf("error hashing password: %s", err)
+ }
+
+ newUserID, err := id.NewRandomULID()
+ if err != nil {
+ return nil, err
+ }
+
+ u := &gtsmodel.User{
+ ID: newUserID,
+ AccountID: acct.ID,
+ EncryptedPassword: string(pw),
+ SignUpIP: signUpIP.To4(),
+ Locale: locale,
+ UnconfirmedEmail: email,
+ CreatedByApplicationID: appID,
+ Approved: !requireApproval, // if we don't require moderator approval, just pre-approve the user
+ }
+
+ if emailVerified {
+ u.ConfirmedAt = time.Now()
+ u.Email = email
+ }
+
+ if admin {
+ u.Admin = true
+ u.Moderator = true
+ }
+
+ if _, err = a.conn.
+ NewInsert().
+ Model(u).
+ Exec(ctx); err != nil {
+ return nil, err
+ }
+
+ return u, nil
+}
+
+func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
+ username := a.config.Host
+
+ // check if instance account already exists
+ existsQ := a.conn.
+ NewSelect().
+ Model(&gtsmodel.Account{}).
+ Where("username = ?", username).
+ Where("? IS NULL", bun.Ident("domain"))
+ count, err := existsQ.Count(ctx)
+ if err != nil && count == 1 {
+ a.log.Infof("instance account %s already exists", username)
+ return nil
+ } else if err != sql.ErrNoRows {
+ return processErrorResponse(err)
+ }
+
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ a.log.Errorf("error creating new rsa key: %s", err)
+ return err
+ }
+
+ aID, err := id.NewRandomULID()
+ if err != nil {
+ return err
+ }
+
+ newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
+ acct := &gtsmodel.Account{
+ ID: aID,
+ Username: a.config.Host,
+ DisplayName: username,
+ URL: newAccountURIs.UserURL,
+ PrivateKey: key,
+ PublicKey: &key.PublicKey,
+ PublicKeyURI: newAccountURIs.PublicKeyURI,
+ ActorType: gtsmodel.ActivityStreamsPerson,
+ URI: newAccountURIs.UserURI,
+ InboxURI: newAccountURIs.InboxURI,
+ OutboxURI: newAccountURIs.OutboxURI,
+ FollowersURI: newAccountURIs.FollowersURI,
+ FollowingURI: newAccountURIs.FollowingURI,
+ FeaturedCollectionURI: newAccountURIs.CollectionURI,
+ }
+
+ insertQ := a.conn.
+ NewInsert().
+ Model(acct)
+
+ if _, err := insertQ.Exec(ctx); err != nil {
+ return err
+ }
+
+ a.log.Infof("instance account %s CREATED with id %s", username, acct.ID)
+ return nil
+}
+
+func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
+ domain := a.config.Host
+
+ // check if instance entry already exists
+ existsQ := a.conn.
+ NewSelect().
+ Model(&gtsmodel.Instance{}).
+ Where("domain = ?", domain)
+
+ count, err := existsQ.Count(ctx)
+ if err != nil && count == 1 {
+ a.log.Infof("instance instance %s already exists", domain)
+ return nil
+ } else if err != sql.ErrNoRows {
+ return processErrorResponse(err)
+ }
+
+ iID, err := id.NewRandomULID()
+ if err != nil {
+ return err
+ }
+
+ i := &gtsmodel.Instance{
+ ID: iID,
+ Domain: domain,
+ Title: domain,
+ URI: fmt.Sprintf("%s://%s", a.config.Protocol, a.config.Host),
+ }
+
+ insertQ := a.conn.
+ NewInsert().
+ Model(i)
+
+ if _, err := insertQ.Exec(ctx); err != nil {
+ return err
+ }
+ a.log.Infof("created instance instance %s with id %s", domain, i.ID)
+ return nil
+}
diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go
new file mode 100644
index 000000000..983b6b810
--- /dev/null
+++ b/internal/db/bundb/basic.go
@@ -0,0 +1,179 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "context"
+ "errors"
+ "strings"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/uptrace/bun"
+)
+
+type basicDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (b *basicDB) Put(ctx context.Context, i interface{}) db.Error {
+ _, err := b.conn.NewInsert().Model(i).Exec(ctx)
+ if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
+ return db.ErrAlreadyExists
+ }
+ return err
+}
+
+func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Error {
+ q := b.conn.
+ NewSelect().
+ Model(i).
+ Where("id = ?", id)
+
+ return processErrorResponse(q.Scan(ctx))
+}
+
+func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
+ if len(where) == 0 {
+ return errors.New("no queries provided")
+ }
+
+ q := b.conn.NewSelect().Model(i)
+ for _, w := range where {
+
+ if w.Value == nil {
+ q = q.Where("? IS NULL", bun.Ident(w.Key))
+ } else {
+ if w.CaseInsensitive {
+ q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value)
+ } else {
+ q = q.Where("? = ?", bun.Safe(w.Key), w.Value)
+ }
+ }
+ }
+
+ return processErrorResponse(q.Scan(ctx))
+}
+
+func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error {
+ q := b.conn.
+ NewSelect().
+ Model(i)
+
+ return processErrorResponse(q.Scan(ctx))
+}
+
+func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error {
+ q := b.conn.
+ NewDelete().
+ Model(i).
+ Where("id = ?", id)
+
+ _, err := q.Exec(ctx)
+
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
+ if len(where) == 0 {
+ return errors.New("no queries provided")
+ }
+
+ q := b.conn.
+ NewDelete().
+ Model(i)
+
+ for _, w := range where {
+ q = q.Where("? = ?", bun.Safe(w.Key), w.Value)
+ }
+
+ _, err := q.Exec(ctx)
+
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.Error {
+ q := b.conn.
+ NewUpdate().
+ Model(i).
+ WherePK()
+
+ _, err := q.Exec(ctx)
+
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) db.Error {
+ q := b.conn.NewUpdate().
+ Model(i).
+ Set("? = ?", bun.Safe(key), value).
+ WherePK()
+
+ _, err := q.Exec(ctx)
+
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error {
+ q := b.conn.NewUpdate().Model(i)
+
+ for _, w := range where {
+ if w.Value == nil {
+ q = q.Where("? IS NULL", bun.Ident(w.Key))
+ } else {
+ if w.CaseInsensitive {
+ q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value)
+ } else {
+ q = q.Where("? = ?", bun.Safe(w.Key), w.Value)
+ }
+ }
+ }
+
+ q = q.Set("? = ?", bun.Safe(key), value)
+
+ _, err := q.Exec(ctx)
+
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
+ _, err := b.conn.NewCreateTable().Model(i).IfNotExists().Exec(ctx)
+ return err
+}
+
+func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error {
+ _, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx)
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
+ return b.conn.Ping()
+}
+
+func (b *basicDB) Stop(ctx context.Context) db.Error {
+ b.log.Info("closing db connection")
+ if err := b.conn.Close(); err != nil {
+ // only cancel if there's a problem closing the db
+ return err
+ }
+ return nil
+}
diff --git a/internal/db/bundb/basic_test.go b/internal/db/bundb/basic_test.go
new file mode 100644
index 000000000..9189618c9
--- /dev/null
+++ b/internal/db/bundb/basic_test.go
@@ -0,0 +1,68 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type BasicTestSuite struct {
+ BunDBStandardTestSuite
+}
+
+func (suite *BasicTestSuite) SetupSuite() {
+ suite.testTokens = testrig.NewTestTokens()
+ suite.testClients = testrig.NewTestClients()
+ suite.testApplications = testrig.NewTestApplications()
+ suite.testUsers = testrig.NewTestUsers()
+ suite.testAccounts = testrig.NewTestAccounts()
+ suite.testAttachments = testrig.NewTestAttachments()
+ suite.testStatuses = testrig.NewTestStatuses()
+ suite.testTags = testrig.NewTestTags()
+ suite.testMentions = testrig.NewTestMentions()
+}
+
+func (suite *BasicTestSuite) SetupTest() {
+ suite.config = testrig.NewTestConfig()
+ suite.db = testrig.NewTestDB()
+ suite.log = testrig.NewTestLog()
+
+ testrig.StandardDBSetup(suite.db, suite.testAccounts)
+}
+
+func (suite *BasicTestSuite) TearDownTest() {
+ testrig.StandardDBTeardown(suite.db)
+}
+
+func (suite *BasicTestSuite) TestGetAccountByID() {
+ testAccount := suite.testAccounts["local_account_1"]
+
+ a := &gtsmodel.Account{}
+ err := suite.db.GetByID(context.Background(), testAccount.ID, a)
+ suite.NoError(err)
+}
+
+func TestBasicTestSuite(t *testing.T) {
+ suite.Run(t, new(BasicTestSuite))
+}
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
new file mode 100644
index 000000000..49ed09cbd
--- /dev/null
+++ b/internal/db/bundb/bundb.go
@@ -0,0 +1,410 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "database/sql"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/jackc/pgx/v4"
+ "github.com/jackc/pgx/v4/stdlib"
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/id"
+ "github.com/uptrace/bun"
+ "github.com/uptrace/bun/dialect/pgdialect"
+)
+
+const (
+ dbTypePostgres = "postgres"
+ dbTypeSqlite = "sqlite"
+)
+
+var registerTables []interface{} = []interface{}{
+ &gtsmodel.StatusToEmoji{},
+ &gtsmodel.StatusToTag{},
+}
+
+// bunDBService satisfies the DB interface
+type bunDBService struct {
+ db.Account
+ db.Admin
+ db.Basic
+ db.Domain
+ db.Instance
+ db.Media
+ db.Mention
+ db.Notification
+ db.Relationship
+ db.Session
+ db.Status
+ db.Timeline
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface.
+// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
+func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
+ var sqldb *sql.DB
+ var conn *bun.DB
+
+ // depending on the database type we're trying to create, we need to use a different driver...
+ switch strings.ToLower(c.DBConfig.Type) {
+ case dbTypePostgres:
+ // POSTGRES
+ opts, err := deriveBunDBPGOptions(c)
+ if err != nil {
+ return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
+ }
+ sqldb = stdlib.OpenDB(*opts)
+ conn = bun.NewDB(sqldb, pgdialect.New())
+ case dbTypeSqlite:
+ // SQLITE
+ // TODO: https://bun.uptrace.dev/guide/drivers.html#sqlite
+ default:
+ return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type))
+ }
+
+ // actually *begin* the connection so that we can tell if the db is there and listening
+ if err := conn.Ping(); err != nil {
+ return nil, fmt.Errorf("db connection error: %s", err)
+ }
+ log.Info("connected to database")
+
+ for _, t := range registerTables {
+ // https://bun.uptrace.dev/orm/many-to-many-relation/
+ conn.RegisterModel(t)
+ }
+
+ ps := &bunDBService{
+ Account: &accountDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Admin: &adminDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Basic: &basicDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Domain: &domainDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Instance: &instanceDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Media: &mediaDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Mention: &mentionDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Notification: &notificationDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Relationship: &relationshipDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Session: &sessionDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Status: &statusDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Timeline: &timelineDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ config: c,
+ conn: conn,
+ log: log,
+ }
+
+ // we can confidently return this useable service now
+ return ps, nil
+}
+
+/*
+ HANDY STUFF
+*/
+
+// deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options
+// with sensible defaults, or an error if it's not satisfied by the provided config.
+func deriveBunDBPGOptions(c *config.Config) (*pgx.ConnConfig, error) {
+ if strings.ToUpper(c.DBConfig.Type) != db.DBTypePostgres {
+ return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, c.DBConfig.Type)
+ }
+
+ // validate port
+ if c.DBConfig.Port == 0 {
+ return nil, errors.New("no port set")
+ }
+
+ // validate address
+ if c.DBConfig.Address == "" {
+ return nil, errors.New("no address set")
+ }
+
+ // validate username
+ if c.DBConfig.User == "" {
+ return nil, errors.New("no user set")
+ }
+
+ // validate that there's a password
+ if c.DBConfig.Password == "" {
+ return nil, errors.New("no password set")
+ }
+
+ // validate database
+ if c.DBConfig.Database == "" {
+ return nil, errors.New("no database set")
+ }
+
+ var tlsConfig *tls.Config
+ switch c.DBConfig.TLSMode {
+ case config.DBTLSModeDisable, config.DBTLSModeUnset:
+ break // nothing to do
+ case config.DBTLSModeEnable:
+ tlsConfig = &tls.Config{
+ InsecureSkipVerify: true,
+ }
+ case config.DBTLSModeRequire:
+ tlsConfig = &tls.Config{
+ InsecureSkipVerify: false,
+ ServerName: c.DBConfig.Address,
+ }
+ }
+
+ if tlsConfig != nil && c.DBConfig.TLSCACert != "" {
+ // load the system cert pool first -- we'll append the given CA cert to this
+ certPool, err := x509.SystemCertPool()
+ if err != nil {
+ return nil, fmt.Errorf("error fetching system CA cert pool: %s", err)
+ }
+
+ // open the file itself and make sure there's something in it
+ caCertBytes, err := os.ReadFile(c.DBConfig.TLSCACert)
+ if err != nil {
+ return nil, fmt.Errorf("error opening CA certificate at %s: %s", c.DBConfig.TLSCACert, err)
+ }
+ if len(caCertBytes) == 0 {
+ return nil, fmt.Errorf("ca cert at %s was empty", c.DBConfig.TLSCACert)
+ }
+
+ // make sure we have a PEM block
+ caPem, _ := pem.Decode(caCertBytes)
+ if caPem == nil {
+ return nil, fmt.Errorf("could not parse cert at %s into PEM", c.DBConfig.TLSCACert)
+ }
+
+ // parse the PEM block into the certificate
+ caCert, err := x509.ParseCertificate(caPem.Bytes)
+ if err != nil {
+ return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", c.DBConfig.TLSCACert, err)
+ }
+
+ // we're happy, add it to the existing pool and then use this pool in our tls config
+ certPool.AddCert(caCert)
+ tlsConfig.RootCAs = certPool
+ }
+
+ cfg, _ := pgx.ParseConfig("")
+ cfg.Host = c.DBConfig.Address
+ cfg.Port = uint16(c.DBConfig.Port)
+ cfg.User = c.DBConfig.User
+ cfg.Password = c.DBConfig.Password
+ cfg.TLSConfig = tlsConfig
+ cfg.Database = c.DBConfig.Database
+ cfg.PreferSimpleProtocol = true
+
+ return cfg, nil
+}
+
+/*
+ CONVERSION FUNCTIONS
+*/
+
+// TODO: move these to the type converter, it's bananas that they're here and not there
+
+func (ps *bunDBService) MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) {
+ ogAccount := &gtsmodel.Account{}
+ if err := ps.conn.NewSelect().Model(ogAccount).Where("id = ?", originAccountID).Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ menchies := []*gtsmodel.Mention{}
+ for _, a := range targetAccounts {
+ // A mentioned account looks like "@test@example.org" or just "@test" for a local account
+ // -- we can guarantee this from the regex that targetAccounts should have been derived from.
+ // But we still need to do a bit of fiddling to get what we need here -- the username and domain (if given).
+
+ // 1. trim off the first @
+ t := strings.TrimPrefix(a, "@")
+
+ // 2. split the username and domain
+ s := strings.Split(t, "@")
+
+ // 3. if it's length 1 it's a local account, length 2 means remote, anything else means something is wrong
+ var local bool
+ switch len(s) {
+ case 1:
+ local = true
+ case 2:
+ local = false
+ default:
+ return nil, fmt.Errorf("mentioned account format '%s' was not valid", a)
+ }
+
+ var username, domain string
+ username = s[0]
+ if !local {
+ domain = s[1]
+ }
+
+ // 4. check we now have a proper username and domain
+ if username == "" || (!local && domain == "") {
+ return nil, fmt.Errorf("username or domain for '%s' was nil", a)
+ }
+
+ // okay we're good now, we can start pulling accounts out of the database
+ mentionedAccount := &gtsmodel.Account{}
+ var err error
+
+ // match username + account, case insensitive
+ if local {
+ // local user -- should have a null domain
+ err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("? IS NULL", bun.Ident("domain")).Scan(ctx)
+ } else {
+ // remote user -- should have domain defined
+ err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("LOWER(?) = LOWER(?)", bun.Ident("domain"), domain).Scan(ctx)
+ }
+
+ if err != nil {
+ if err == sql.ErrNoRows {
+ // no result found for this username/domain so just don't include it as a mencho and carry on about our business
+ ps.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain)
+ continue
+ }
+ // a serious error has happened so bail
+ return nil, fmt.Errorf("error getting account with username '%s' and domain '%s': %s", username, domain, err)
+ }
+
+ // id, createdAt and updatedAt will be populated by the db, so we have everything we need!
+ menchies = append(menchies, &gtsmodel.Mention{
+ StatusID: statusID,
+ OriginAccountID: ogAccount.ID,
+ OriginAccountURI: ogAccount.URI,
+ TargetAccountID: mentionedAccount.ID,
+ NameString: a,
+ TargetAccountURI: mentionedAccount.URI,
+ TargetAccountURL: mentionedAccount.URL,
+ OriginAccount: mentionedAccount,
+ })
+ }
+ return menchies, nil
+}
+
+func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) {
+ newTags := []*gtsmodel.Tag{}
+ for _, t := range tags {
+ tag := &gtsmodel.Tag{}
+ // we can use selectorinsert here to create the new tag if it doesn't exist already
+ // inserted will be true if this is a new tag we just created
+ if err := ps.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil {
+ if err == sql.ErrNoRows {
+ // tag doesn't exist yet so populate it
+ newID, err := id.NewRandomULID()
+ if err != nil {
+ return nil, err
+ }
+ tag.ID = newID
+ tag.URL = fmt.Sprintf("%s://%s/tags/%s", ps.config.Protocol, ps.config.Host, t)
+ tag.Name = t
+ tag.FirstSeenFromAccountID = originAccountID
+ tag.CreatedAt = time.Now()
+ tag.UpdatedAt = time.Now()
+ tag.Useable = true
+ tag.Listable = true
+ } else {
+ return nil, fmt.Errorf("error getting tag with name %s: %s", t, err)
+ }
+ }
+
+ // bail already if the tag isn't useable
+ if !tag.Useable {
+ continue
+ }
+ tag.LastStatusAt = time.Now()
+ newTags = append(newTags, tag)
+ }
+ return newTags, nil
+}
+
+func (ps *bunDBService) EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) {
+ newEmojis := []*gtsmodel.Emoji{}
+ for _, e := range emojis {
+ emoji := &gtsmodel.Emoji{}
+ err := ps.conn.NewSelect().Model(emoji).Where("shortcode = ?", e).Where("visible_in_picker = true").Where("disabled = false").Scan(ctx)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ // no result found for this username/domain so just don't include it as an emoji and carry on about our business
+ ps.log.Debugf("no emoji found with shortcode %s, skipping it", e)
+ continue
+ }
+ // a serious error has happened so bail
+ return nil, fmt.Errorf("error getting emoji with shortcode %s: %s", e, err)
+ }
+ newEmojis = append(newEmojis, emoji)
+ }
+ return newEmojis, nil
+}
diff --git a/internal/db/bundb/bundb_test.go b/internal/db/bundb/bundb_test.go
new file mode 100644
index 000000000..b789375af
--- /dev/null
+++ b/internal/db/bundb/bundb_test.go
@@ -0,0 +1,47 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb_test
+
+import (
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/oauth"
+)
+
+type BunDBStandardTestSuite struct {
+ // standard suite interfaces
+ suite.Suite
+ config *config.Config
+ db db.DB
+ log *logrus.Logger
+
+ // standard suite models
+ testTokens map[string]*oauth.Token
+ testClients map[string]*oauth.Client
+ testApplications map[string]*gtsmodel.Application
+ testUsers map[string]*gtsmodel.User
+ testAccounts map[string]*gtsmodel.Account
+ testAttachments map[string]*gtsmodel.MediaAttachment
+ testStatuses map[string]*gtsmodel.Status
+ testTags map[string]*gtsmodel.Tag
+ testMentions map[string]*gtsmodel.Mention
+}
diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go
new file mode 100644
index 000000000..6aa2b8ffe
--- /dev/null
+++ b/internal/db/bundb/domain.go
@@ -0,0 +1,81 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "context"
+ "net/url"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
+ "github.com/uptrace/bun"
+)
+
+type domainDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) {
+ if domain == "" {
+ return false, nil
+ }
+
+ q := d.conn.
+ NewSelect().
+ Model(&gtsmodel.DomainBlock{}).
+ Where("LOWER(domain) = LOWER(?)", domain).
+ Limit(1)
+
+ return exists(ctx, q)
+}
+
+func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) {
+ // filter out any doubles
+ uniqueDomains := util.UniqueStrings(domains)
+
+ for _, domain := range uniqueDomains {
+ if blocked, err := d.IsDomainBlocked(ctx, domain); err != nil {
+ return false, err
+ } else if blocked {
+ return blocked, nil
+ }
+ }
+
+ // no blocks found
+ return false, nil
+}
+
+func (d *domainDB) IsURIBlocked(ctx context.Context, uri *url.URL) (bool, db.Error) {
+ domain := uri.Hostname()
+ return d.IsDomainBlocked(ctx, domain)
+}
+
+func (d *domainDB) AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, db.Error) {
+ domains := []string{}
+ for _, uri := range uris {
+ domains = append(domains, uri.Hostname())
+ }
+
+ return d.AreDomainsBlocked(ctx, domains)
+}
diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go
new file mode 100644
index 000000000..f9364346e
--- /dev/null
+++ b/internal/db/bundb/instance.go
@@ -0,0 +1,118 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "context"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+type instanceDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {
+ q := i.conn.
+ NewSelect().
+ Model(&[]*gtsmodel.Account{})
+
+ if domain == i.config.Host {
+ // if the domain is *this* domain, just count where the domain field is null
+ q = q.Where("? IS NULL", bun.Ident("domain"))
+ } else {
+ q = q.Where("domain = ?", domain)
+ }
+
+ // don't count the instance account or suspended users
+ q = q.
+ Where("username != ?", domain).
+ Where("? IS NULL", bun.Ident("suspended_at"))
+
+ count, err := q.Count(ctx)
+
+ return count, processErrorResponse(err)
+}
+
+func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) {
+ q := i.conn.
+ NewSelect().
+ Model(&[]*gtsmodel.Status{})
+
+ if domain == i.config.Host {
+ // if the domain is *this* domain, just count where local is true
+ q = q.Where("local = ?", true)
+ } else {
+ // join on the domain of the account
+ q = q.Join("JOIN accounts AS account ON account.id = status.account_id").
+ Where("account.domain = ?", domain)
+ }
+
+ count, err := q.Count(ctx)
+
+ return count, processErrorResponse(err)
+}
+
+func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) {
+ q := i.conn.
+ NewSelect().
+ Model(&[]*gtsmodel.Instance{})
+
+ if domain == i.config.Host {
+ // if the domain is *this* domain, just count other instances it knows about
+ // exclude domains that are blocked
+ q = q.Where("domain != ?", domain).Where("? IS NULL", bun.Ident("suspended_at"))
+ } else {
+ // TODO: implement federated domain counting properly for remote domains
+ return 0, nil
+ }
+
+ count, err := q.Count(ctx)
+
+ return count, processErrorResponse(err)
+}
+
+func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
+ i.log.Debug("GetAccountsForInstance")
+
+ accounts := []*gtsmodel.Account{}
+
+ q := i.conn.NewSelect().
+ Model(&accounts).
+ Where("domain = ?", domain).
+ Order("id DESC")
+
+ if maxID != "" {
+ q = q.Where("id < ?", maxID)
+ }
+
+ if limit > 0 {
+ q = q.Limit(limit)
+ }
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return accounts, err
+}
diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go
new file mode 100644
index 000000000..04e55ca62
--- /dev/null
+++ b/internal/db/bundb/media.go
@@ -0,0 +1,53 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "context"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+type mediaDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (m *mediaDB) newMediaQ(i interface{}) *bun.SelectQuery {
+ return m.conn.
+ NewSelect().
+ Model(i).
+ Relation("Account")
+}
+
+func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, db.Error) {
+ attachment := &gtsmodel.MediaAttachment{}
+
+ q := m.newMediaQ(attachment).
+ Where("media_attachment.id = ?", id)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return attachment, err
+}
diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go
new file mode 100644
index 000000000..a444f9b5f
--- /dev/null
+++ b/internal/db/bundb/mention.go
@@ -0,0 +1,108 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "context"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+type mentionDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+ cache cache.Cache
+}
+
+func (m *mentionDB) cacheMention(id string, mention *gtsmodel.Mention) {
+ if m.cache == nil {
+ m.cache = cache.New()
+ }
+
+ if err := m.cache.Store(id, mention); err != nil {
+ m.log.Panicf("mentionDB: error storing in cache: %s", err)
+ }
+}
+
+func (m *mentionDB) mentionCached(id string) (*gtsmodel.Mention, bool) {
+ if m.cache == nil {
+ m.cache = cache.New()
+ return nil, false
+ }
+
+ mI, err := m.cache.Fetch(id)
+ if err != nil || mI == nil {
+ return nil, false
+ }
+
+ mention, ok := mI.(*gtsmodel.Mention)
+ if !ok {
+ m.log.Panicf("mentionDB: cached interface with key %s was not a mention", id)
+ }
+
+ return mention, true
+}
+
+func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {
+ return m.conn.
+ NewSelect().
+ Model(i).
+ Relation("Status").
+ Relation("OriginAccount").
+ Relation("TargetAccount")
+}
+
+func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
+ if mention, cached := m.mentionCached(id); cached {
+ return mention, nil
+ }
+
+ mention := &gtsmodel.Mention{}
+
+ q := m.newMentionQ(mention).
+ Where("mention.id = ?", id)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ if err == nil && mention != nil {
+ m.cacheMention(id, mention)
+ }
+
+ return mention, err
+}
+
+func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) {
+ mentions := []*gtsmodel.Mention{}
+
+ for _, i := range ids {
+ mention, err := m.GetMention(ctx, i)
+ if err != nil {
+ return nil, processErrorResponse(err)
+ }
+ mentions = append(mentions, mention)
+ }
+
+ return mentions, nil
+}
diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go
new file mode 100644
index 000000000..1c30837ec
--- /dev/null
+++ b/internal/db/bundb/notification.go
@@ -0,0 +1,136 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "context"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+type notificationDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+ cache cache.Cache
+}
+
+func (n *notificationDB) cacheNotification(id string, notification *gtsmodel.Notification) {
+ if n.cache == nil {
+ n.cache = cache.New()
+ }
+
+ if err := n.cache.Store(id, notification); err != nil {
+ n.log.Panicf("notificationDB: error storing in cache: %s", err)
+ }
+}
+
+func (n *notificationDB) notificationCached(id string) (*gtsmodel.Notification, bool) {
+ if n.cache == nil {
+ n.cache = cache.New()
+ return nil, false
+ }
+
+ nI, err := n.cache.Fetch(id)
+ if err != nil || nI == nil {
+ return nil, false
+ }
+
+ notification, ok := nI.(*gtsmodel.Notification)
+ if !ok {
+ n.log.Panicf("notificationDB: cached interface with key %s was not a notification", id)
+ }
+
+ return notification, true
+}
+
+func (n *notificationDB) newNotificationQ(i interface{}) *bun.SelectQuery {
+ return n.conn.
+ NewSelect().
+ Model(i).
+ Relation("OriginAccount").
+ Relation("TargetAccount").
+ Relation("Status")
+}
+
+func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
+ if notification, cached := n.notificationCached(id); cached {
+ return notification, nil
+ }
+
+ notification := &gtsmodel.Notification{}
+
+ q := n.newNotificationQ(notification).
+ Where("notification.id = ?", id)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ if err == nil && notification != nil {
+ n.cacheNotification(id, notification)
+ }
+
+ return notification, err
+}
+
+func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
+ // begin by selecting just the IDs
+ notifIDs := []*gtsmodel.Notification{}
+ q := n.conn.
+ NewSelect().
+ Model(&notifIDs).
+ Column("id").
+ Where("target_account_id = ?", accountID).
+ Order("id DESC")
+
+ if maxID != "" {
+ q = q.Where("id < ?", maxID)
+ }
+
+ if sinceID != "" {
+ q = q.Where("id > ?", sinceID)
+ }
+
+ if limit != 0 {
+ q = q.Limit(limit)
+ }
+
+ err := processErrorResponse(q.Scan(ctx))
+ if err != nil {
+ return nil, err
+ }
+
+ // now we have the IDs, select the notifs one by one
+ // reason for this is that for each notif, we can instead get it from our cache if it's cached
+ notifications := []*gtsmodel.Notification{}
+ for _, notifID := range notifIDs {
+ notif, err := n.GetNotification(ctx, notifID.ID)
+ errP := processErrorResponse(err)
+ if errP != nil {
+ return nil, errP
+ }
+ notifications = append(notifications, notif)
+ }
+
+ return notifications, nil
+}
diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go
new file mode 100644
index 000000000..ccc604baf
--- /dev/null
+++ b/internal/db/bundb/relationship.go
@@ -0,0 +1,328 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+type relationshipDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *bun.SelectQuery {
+ return r.conn.
+ NewSelect().
+ Model(block).
+ Relation("Account").
+ Relation("TargetAccount")
+}
+
+func (r *relationshipDB) newFollowQ(follow interface{}) *bun.SelectQuery {
+ return r.conn.
+ NewSelect().
+ Model(follow).
+ Relation("Account").
+ Relation("TargetAccount")
+}
+
+func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
+ q := r.conn.
+ NewSelect().
+ Model(&gtsmodel.Block{}).
+ Where("account_id = ?", account1).
+ Where("target_account_id = ?", account2).
+ Limit(1)
+
+ if eitherDirection {
+ q = q.
+ WhereOr("target_account_id = ?", account1).
+ Where("account_id = ?", account2)
+ }
+
+ return exists(ctx, q)
+}
+
+func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
+ block := &gtsmodel.Block{}
+
+ q := r.newBlockQ(block).
+ Where("block.account_id = ?", account1).
+ Where("block.target_account_id = ?", account2)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return block, err
+}
+
+func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
+ rel := &gtsmodel.Relationship{
+ ID: targetAccount,
+ }
+
+ // check if the requesting account follows the target account
+ follow := &gtsmodel.Follow{}
+ if err := r.conn.
+ NewSelect().
+ Model(follow).
+ Where("account_id = ?", requestingAccount).
+ Where("target_account_id = ?", targetAccount).
+ Limit(1).
+ Scan(ctx); err != nil {
+ if err != sql.ErrNoRows {
+ // a proper error
+ return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
+ }
+ // no follow exists so these are all false
+ rel.Following = false
+ rel.ShowingReblogs = false
+ rel.Notifying = false
+ } else {
+ // follow exists so we can fill these fields out...
+ rel.Following = true
+ rel.ShowingReblogs = follow.ShowReblogs
+ rel.Notifying = follow.Notify
+ }
+
+ // check if the target account follows the requesting account
+ count, err := r.conn.
+ NewSelect().
+ Model(&gtsmodel.Follow{}).
+ Where("account_id = ?", targetAccount).
+ Where("target_account_id = ?", requestingAccount).
+ Limit(1).
+ Count(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
+ }
+ rel.FollowedBy = count > 0
+
+ // check if the requesting account blocks the target account
+ count, err = r.conn.NewSelect().
+ Model(&gtsmodel.Block{}).
+ Where("account_id = ?", requestingAccount).
+ Where("target_account_id = ?", targetAccount).
+ Limit(1).
+ Count(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
+ }
+ rel.Blocking = count > 0
+
+ // check if the target account blocks the requesting account
+ count, err = r.conn.
+ NewSelect().
+ Model(&gtsmodel.Block{}).
+ Where("account_id = ?", targetAccount).
+ Where("target_account_id = ?", requestingAccount).
+ Limit(1).
+ Count(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
+ }
+ rel.BlockedBy = count > 0
+
+ // check if there's a pending following request from requesting account to target account
+ count, err = r.conn.
+ NewSelect().
+ Model(&gtsmodel.FollowRequest{}).
+ Where("account_id = ?", requestingAccount).
+ Where("target_account_id = ?", targetAccount).
+ Limit(1).
+ Count(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
+ }
+ rel.Requested = count > 0
+
+ return rel, nil
+}
+
+func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
+ if sourceAccount == nil || targetAccount == nil {
+ return false, nil
+ }
+
+ q := r.conn.
+ NewSelect().
+ Model(&gtsmodel.Follow{}).
+ Where("account_id = ?", sourceAccount.ID).
+ Where("target_account_id = ?", targetAccount.ID).
+ Limit(1)
+
+ return exists(ctx, q)
+}
+
+func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
+ if sourceAccount == nil || targetAccount == nil {
+ return false, nil
+ }
+
+ q := r.conn.
+ NewSelect().
+ Model(&gtsmodel.FollowRequest{}).
+ Where("account_id = ?", sourceAccount.ID).
+ Where("target_account_id = ?", targetAccount.ID)
+
+ return exists(ctx, q)
+}
+
+func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
+ if account1 == nil || account2 == nil {
+ return false, nil
+ }
+
+ // make sure account 1 follows account 2
+ f1, err := r.IsFollowing(ctx, account1, account2)
+ if err != nil {
+ return false, processErrorResponse(err)
+ }
+
+ // make sure account 2 follows account 1
+ f2, err := r.IsFollowing(ctx, account2, account1)
+ if err != nil {
+ return false, processErrorResponse(err)
+ }
+
+ return f1 && f2, nil
+}
+
+func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
+ // make sure the original follow request exists
+ fr := &gtsmodel.FollowRequest{}
+ if err := r.conn.
+ NewSelect().
+ Model(fr).
+ Where("account_id = ?", originAccountID).
+ Where("target_account_id = ?", targetAccountID).
+ Scan(ctx); err != nil {
+ return nil, processErrorResponse(err)
+ }
+
+ // create a new follow to 'replace' the request with
+ follow := &gtsmodel.Follow{
+ ID: fr.ID,
+ AccountID: originAccountID,
+ TargetAccountID: targetAccountID,
+ URI: fr.URI,
+ }
+
+ // if the follow already exists, just update the URI -- we don't need to do anything else
+ if _, err := r.conn.
+ NewInsert().
+ Model(follow).
+ On("CONFLICT CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).
+ Exec(ctx); err != nil {
+ return nil, processErrorResponse(err)
+ }
+
+ // now remove the follow request
+ if _, err := r.conn.
+ NewDelete().
+ Model(&gtsmodel.FollowRequest{}).
+ Where("account_id = ?", originAccountID).
+ Where("target_account_id = ?", targetAccountID).
+ Exec(ctx); err != nil {
+ return nil, processErrorResponse(err)
+ }
+
+ return follow, nil
+}
+
+func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) {
+ followRequests := []*gtsmodel.FollowRequest{}
+
+ q := r.newFollowQ(&followRequests).
+ Where("target_account_id = ?", accountID)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return followRequests, err
+}
+
+func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, db.Error) {
+ follows := []*gtsmodel.Follow{}
+
+ q := r.newFollowQ(&follows).
+ Where("account_id = ?", accountID)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return follows, err
+}
+
+func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
+ return r.conn.
+ NewSelect().
+ Model(&[]*gtsmodel.Follow{}).
+ Where("account_id = ?", accountID).
+ Count(ctx)
+}
+
+func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
+
+ follows := []*gtsmodel.Follow{}
+
+ q := r.conn.
+ NewSelect().
+ Model(&follows)
+
+ if localOnly {
+ // for local accounts let's get where domain is null OR where domain is an empty string, just to be safe
+ whereGroup := func(q *bun.SelectQuery) *bun.SelectQuery {
+ q = q.
+ WhereOr("? IS NULL", bun.Ident("a.domain")).
+ WhereOr("a.domain = ?", "")
+ return q
+ }
+
+ q = q.ColumnExpr("follow.*").
+ Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)").
+ Where("follow.target_account_id = ?", accountID).
+ WhereGroup(" AND ", whereGroup)
+ } else {
+ q = q.Where("target_account_id = ?", accountID)
+ }
+
+ if err := q.Scan(ctx); err != nil {
+ if err == sql.ErrNoRows {
+ return follows, nil
+ }
+ return nil, processErrorResponse(err)
+ }
+ return follows, nil
+}
+
+func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
+ return r.conn.
+ NewSelect().
+ Model(&[]*gtsmodel.Follow{}).
+ Where("target_account_id = ?", accountID).
+ Count(ctx)
+}
diff --git a/internal/db/bundb/session.go b/internal/db/bundb/session.go
new file mode 100644
index 000000000..87e20673d
--- /dev/null
+++ b/internal/db/bundb/session.go
@@ -0,0 +1,85 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "context"
+ "crypto/rand"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/id"
+ "github.com/uptrace/bun"
+)
+
+type sessionDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
+ rs := new(gtsmodel.RouterSession)
+
+ q := s.conn.
+ NewSelect().
+ Model(rs).
+ Limit(1)
+
+ _, err := q.Exec(ctx)
+
+ err = processErrorResponse(err)
+
+ return rs, err
+}
+
+func (s *sessionDB) CreateSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
+ auth := make([]byte, 32)
+ crypt := make([]byte, 32)
+
+ if _, err := rand.Read(auth); err != nil {
+ return nil, err
+ }
+ if _, err := rand.Read(crypt); err != nil {
+ return nil, err
+ }
+
+ rid, err := id.NewULID()
+ if err != nil {
+ return nil, err
+ }
+
+ rs := &gtsmodel.RouterSession{
+ ID: rid,
+ Auth: auth,
+ Crypt: crypt,
+ }
+
+ q := s.conn.
+ NewInsert().
+ Model(rs)
+
+ _, err = q.Exec(ctx)
+
+ err = processErrorResponse(err)
+
+ return rs, err
+}
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go
new file mode 100644
index 000000000..da8d8ca41
--- /dev/null
+++ b/internal/db/bundb/status.go
@@ -0,0 +1,375 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "container/list"
+ "context"
+ "errors"
+ "time"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+type statusDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+ cache cache.Cache
+}
+
+func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) {
+ if s.cache == nil {
+ s.cache = cache.New()
+ }
+
+ if err := s.cache.Store(id, status); err != nil {
+ s.log.Panicf("statusDB: error storing in cache: %s", err)
+ }
+}
+
+func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) {
+ if s.cache == nil {
+ s.cache = cache.New()
+ return nil, false
+ }
+
+ sI, err := s.cache.Fetch(id)
+ if err != nil || sI == nil {
+ return nil, false
+ }
+
+ status, ok := sI.(*gtsmodel.Status)
+ if !ok {
+ s.log.Panicf("statusDB: cached interface with key %s was not a status", id)
+ }
+
+ return status, true
+}
+
+func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
+ return s.conn.
+ NewSelect().
+ Model(status).
+ Relation("Attachments").
+ Relation("Tags").
+ Relation("Mentions").
+ Relation("Emojis").
+ Relation("Account").
+ Relation("InReplyToAccount").
+ Relation("BoostOfAccount").
+ Relation("CreatedWithApplication")
+}
+
+func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status {
+ if status.InReplyToID != "" && status.InReplyTo == nil {
+ if inReplyTo, cached := s.statusCached(status.InReplyToID); cached {
+ status.InReplyTo = inReplyTo
+ } else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil {
+ status.InReplyTo = inReplyTo
+ }
+ }
+
+ if status.BoostOfID != "" && status.BoostOf == nil {
+ if boostOf, cached := s.statusCached(status.BoostOfID); cached {
+ status.BoostOf = boostOf
+ } else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil {
+ status.BoostOf = boostOf
+ }
+ }
+
+ return status
+}
+
+func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
+ return s.conn.
+ NewSelect().
+ Model(faves).
+ Relation("Account").
+ Relation("TargetAccount").
+ Relation("Status")
+}
+
+func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) {
+ if status, cached := s.statusCached(id); cached {
+ return status, nil
+ }
+
+ status := new(gtsmodel.Status)
+
+ q := s.newStatusQ(status).
+ Where("status.id = ?", id)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ if err != nil {
+ return nil, err
+ }
+
+ if status != nil {
+ s.cacheStatus(id, status)
+ }
+
+ return s.getAttachedStatuses(ctx, status), err
+}
+
+func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
+ if status, cached := s.statusCached(uri); cached {
+ return status, nil
+ }
+
+ status := &gtsmodel.Status{}
+
+ q := s.newStatusQ(status).
+ Where("LOWER(status.uri) = LOWER(?)", uri)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ if err != nil {
+ return nil, err
+ }
+
+ if status != nil {
+ s.cacheStatus(uri, status)
+ }
+
+ return s.getAttachedStatuses(ctx, status), err
+}
+
+func (s *statusDB) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
+ if status, cached := s.statusCached(uri); cached {
+ return status, nil
+ }
+
+ status := &gtsmodel.Status{}
+
+ q := s.newStatusQ(status).
+ Where("LOWER(status.url) = LOWER(?)", uri)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ if err != nil {
+ return nil, err
+ }
+
+ if status != nil {
+ s.cacheStatus(uri, status)
+ }
+
+ return s.getAttachedStatuses(ctx, status), err
+}
+
+func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
+ transaction := func(ctx context.Context, tx bun.Tx) error {
+ // create links between this status and any emojis it uses
+ for _, i := range status.EmojiIDs {
+ if _, err := tx.NewInsert().Model(&gtsmodel.StatusToEmoji{
+ StatusID: status.ID,
+ EmojiID: i,
+ }).Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ // create links between this status and any tags it uses
+ for _, i := range status.TagIDs {
+ if _, err := tx.NewInsert().Model(&gtsmodel.StatusToTag{
+ StatusID: status.ID,
+ TagID: i,
+ }).Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ // change the status ID of the media attachments to the new status
+ for _, a := range status.Attachments {
+ a.StatusID = status.ID
+ a.UpdatedAt = time.Now()
+ if _, err := s.conn.NewUpdate().Model(a).
+ Where("id = ?", a.ID).
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ _, err := tx.NewInsert().Model(status).Exec(ctx)
+ return err
+ }
+
+ return processErrorResponse(s.conn.RunInTx(ctx, nil, transaction))
+}
+
+func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
+ parents := []*gtsmodel.Status{}
+ s.statusParent(ctx, status, &parents, onlyDirect)
+
+ return parents, nil
+}
+
+func (s *statusDB) statusParent(ctx context.Context, status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
+ if status.InReplyToID == "" {
+ return
+ }
+
+ parentStatus, err := s.GetStatusByID(ctx, status.InReplyToID)
+ if err == nil {
+ *foundStatuses = append(*foundStatuses, parentStatus)
+ }
+
+ if onlyDirect {
+ return
+ }
+
+ s.statusParent(ctx, parentStatus, foundStatuses, false)
+}
+
+func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) {
+ foundStatuses := &list.List{}
+ foundStatuses.PushFront(status)
+ s.statusChildren(ctx, status, foundStatuses, onlyDirect, minID)
+
+ children := []*gtsmodel.Status{}
+ for e := foundStatuses.Front(); e != nil; e = e.Next() {
+ entry, ok := e.Value.(*gtsmodel.Status)
+ if !ok {
+ panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
+ }
+
+ // only append children, not the overall parent status
+ if entry.ID != status.ID {
+ children = append(children, entry)
+ }
+ }
+
+ return children, nil
+}
+
+func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
+ immediateChildren := []*gtsmodel.Status{}
+
+ q := s.conn.
+ NewSelect().
+ Model(&immediateChildren).
+ Where("in_reply_to_id = ?", status.ID)
+ if minID != "" {
+ q = q.Where("status.id > ?", minID)
+ }
+
+ if err := q.Scan(ctx); err != nil {
+ return
+ }
+
+ for _, child := range immediateChildren {
+ insertLoop:
+ for e := foundStatuses.Front(); e != nil; e = e.Next() {
+ entry, ok := e.Value.(*gtsmodel.Status)
+ if !ok {
+ panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
+ }
+
+ if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {
+ foundStatuses.InsertAfter(child, e)
+ break insertLoop
+ }
+ }
+
+ // only do one loop if we only want direct children
+ if onlyDirect {
+ return
+ }
+ s.statusChildren(ctx, child, foundStatuses, false, minID)
+ }
+}
+
+func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
+ return s.conn.NewSelect().Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count(ctx)
+}
+
+func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
+ return s.conn.NewSelect().Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count(ctx)
+}
+
+func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
+ return s.conn.NewSelect().Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count(ctx)
+}
+
+func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
+ q := s.conn.
+ NewSelect().
+ Model(&gtsmodel.StatusFave{}).
+ Where("status_id = ?", status.ID).
+ Where("account_id = ?", accountID)
+
+ return exists(ctx, q)
+}
+
+func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
+ q := s.conn.
+ NewSelect().
+ Model(&gtsmodel.Status{}).
+ Where("boost_of_id = ?", status.ID).
+ Where("account_id = ?", accountID)
+
+ return exists(ctx, q)
+}
+
+func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
+ q := s.conn.
+ NewSelect().
+ Model(&gtsmodel.StatusMute{}).
+ Where("status_id = ?", status.ID).
+ Where("account_id = ?", accountID)
+
+ return exists(ctx, q)
+}
+
+func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
+ q := s.conn.
+ NewSelect().
+ Model(&gtsmodel.StatusBookmark{}).
+ Where("status_id = ?", status.ID).
+ Where("account_id = ?", accountID)
+
+ return exists(ctx, q)
+}
+
+func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
+ faves := []*gtsmodel.StatusFave{}
+
+ q := s.newFaveQ(&faves).
+ Where("status_id = ?", status.ID)
+
+ err := processErrorResponse(q.Scan(ctx))
+ return faves, err
+}
+
+func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
+ reblogs := []*gtsmodel.Status{}
+
+ q := s.newStatusQ(&reblogs).
+ Where("boost_of_id = ?", status.ID)
+
+ err := processErrorResponse(q.Scan(ctx))
+ return reblogs, err
+}
diff --git a/internal/db/bundb/status_test.go b/internal/db/bundb/status_test.go
new file mode 100644
index 000000000..513000577
--- /dev/null
+++ b/internal/db/bundb/status_test.go
@@ -0,0 +1,136 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb_test
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type StatusTestSuite struct {
+ BunDBStandardTestSuite
+}
+
+func (suite *StatusTestSuite) SetupSuite() {
+ suite.testTokens = testrig.NewTestTokens()
+ suite.testClients = testrig.NewTestClients()
+ suite.testApplications = testrig.NewTestApplications()
+ suite.testUsers = testrig.NewTestUsers()
+ suite.testAccounts = testrig.NewTestAccounts()
+ suite.testAttachments = testrig.NewTestAttachments()
+ suite.testStatuses = testrig.NewTestStatuses()
+ suite.testTags = testrig.NewTestTags()
+ suite.testMentions = testrig.NewTestMentions()
+}
+
+func (suite *StatusTestSuite) SetupTest() {
+ suite.config = testrig.NewTestConfig()
+ suite.db = testrig.NewTestDB()
+ suite.log = testrig.NewTestLog()
+
+ testrig.StandardDBSetup(suite.db, suite.testAccounts)
+}
+
+func (suite *StatusTestSuite) TearDownTest() {
+ testrig.StandardDBTeardown(suite.db)
+}
+
+func (suite *StatusTestSuite) TestGetStatusByID() {
+ status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_1_status_1"].ID)
+ if err != nil {
+ fmt.Println(err.Error())
+ suite.FailNow(err.Error())
+ }
+ suite.NotNil(status)
+ suite.NotNil(status.Account)
+ suite.NotNil(status.CreatedWithApplication)
+ suite.Nil(status.BoostOf)
+ suite.Nil(status.BoostOfAccount)
+ suite.Nil(status.InReplyTo)
+ suite.Nil(status.InReplyToAccount)
+}
+
+func (suite *StatusTestSuite) TestGetStatusByURI() {
+ status, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.NotNil(status)
+ suite.NotNil(status.Account)
+ suite.NotNil(status.CreatedWithApplication)
+ suite.Nil(status.BoostOf)
+ suite.Nil(status.BoostOfAccount)
+ suite.Nil(status.InReplyTo)
+ suite.Nil(status.InReplyToAccount)
+}
+
+func (suite *StatusTestSuite) TestGetStatusWithExtras() {
+ status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["admin_account_status_1"].ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.NotNil(status)
+ suite.NotNil(status.Account)
+ suite.NotNil(status.CreatedWithApplication)
+ suite.NotEmpty(status.Tags)
+ suite.NotEmpty(status.Attachments)
+ suite.NotEmpty(status.Emojis)
+}
+
+func (suite *StatusTestSuite) TestGetStatusWithMention() {
+ status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_2_status_5"].ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.NotNil(status)
+ suite.NotNil(status.Account)
+ suite.NotNil(status.CreatedWithApplication)
+ suite.NotEmpty(status.Mentions)
+ suite.NotEmpty(status.MentionIDs)
+ suite.NotNil(status.InReplyTo)
+ suite.NotNil(status.InReplyToAccount)
+}
+
+func (suite *StatusTestSuite) TestGetStatusTwice() {
+ before1 := time.Now()
+ _, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
+ suite.NoError(err)
+ after1 := time.Now()
+ duration1 := after1.Sub(before1)
+ fmt.Println(duration1.Milliseconds())
+
+ before2 := time.Now()
+ _, err = suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
+ suite.NoError(err)
+ after2 := time.Now()
+ duration2 := after2.Sub(before2)
+ fmt.Println(duration2.Milliseconds())
+
+ // second retrieval should be several orders faster since it will be cached now
+ suite.Less(duration2, duration1)
+}
+
+func TestStatusTestSuite(t *testing.T) {
+ suite.Run(t, new(StatusTestSuite))
+}
diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go
new file mode 100644
index 000000000..b62ad4c50
--- /dev/null
+++ b/internal/db/bundb/timeline.go
@@ -0,0 +1,199 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "context"
+ "database/sql"
+ "sort"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+type timelineDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
+ statuses := []*gtsmodel.Status{}
+ q := t.conn.
+ NewSelect().
+ Model(&statuses)
+
+ q = q.ColumnExpr("status.*").
+ // Find out who accountID follows.
+ Join("LEFT JOIN follows AS f ON f.target_account_id = status.account_id").
+ // Sort by highest ID (newest) to lowest ID (oldest)
+ Order("status.id DESC")
+
+ if maxID != "" {
+ // return only statuses LOWER (ie., older) than maxID
+ q = q.Where("status.id < ?", maxID)
+ }
+
+ if sinceID != "" {
+ // return only statuses HIGHER (ie., newer) than sinceID
+ q = q.Where("status.id > ?", sinceID)
+ }
+
+ if minID != "" {
+ // return only statuses HIGHER (ie., newer) than minID
+ q = q.Where("status.id > ?", minID)
+ }
+
+ if local {
+ // return only statuses posted by local account havers
+ q = q.Where("status.local = ?", local)
+ }
+
+ if limit > 0 {
+ // limit amount of statuses returned
+ q = q.Limit(limit)
+ }
+
+ // Use a WhereGroup here to specify that we want EITHER statuses posted by accounts that accountID follows,
+ // OR statuses posted by accountID itself (since a user should be able to see their own statuses).
+ //
+ // This is equivalent to something like WHERE ... AND (... OR ...)
+ // See: https://bun.uptrace.dev/guide/queries.html#select
+ whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {
+ return q.
+ WhereOr("f.account_id = ?", accountID).
+ WhereOr("status.account_id = ?", accountID)
+ }
+
+ q = q.WhereGroup(" AND ", whereGroup)
+
+ return statuses, processErrorResponse(q.Scan(ctx))
+}
+
+func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
+ statuses := []*gtsmodel.Status{}
+
+ q := t.conn.
+ NewSelect().
+ Model(&statuses).
+ Where("visibility = ?", gtsmodel.VisibilityPublic).
+ Where("? IS NULL", bun.Ident("in_reply_to_id")).
+ Where("? IS NULL", bun.Ident("in_reply_to_uri")).
+ Where("? IS NULL", bun.Ident("boost_of_id")).
+ Order("status.id DESC")
+
+ if maxID != "" {
+ q = q.Where("status.id < ?", maxID)
+ }
+
+ if sinceID != "" {
+ q = q.Where("status.id > ?", sinceID)
+ }
+
+ if minID != "" {
+ q = q.Where("status.id > ?", minID)
+ }
+
+ if local {
+ q = q.Where("status.local = ?", local)
+ }
+
+ if limit > 0 {
+ q = q.Limit(limit)
+ }
+
+ return statuses, processErrorResponse(q.Scan(ctx))
+}
+
+// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
+// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
+func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
+
+ faves := []*gtsmodel.StatusFave{}
+
+ fq := t.conn.
+ NewSelect().
+ Model(&faves).
+ Where("account_id = ?", accountID).
+ Order("id DESC")
+
+ if maxID != "" {
+ fq = fq.Where("id < ?", maxID)
+ }
+
+ if minID != "" {
+ fq = fq.Where("id > ?", minID)
+ }
+
+ if limit > 0 {
+ fq = fq.Limit(limit)
+ }
+
+ err := fq.Scan(ctx)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, "", "", db.ErrNoEntries
+ }
+ return nil, "", "", err
+ }
+
+ if len(faves) == 0 {
+ return nil, "", "", db.ErrNoEntries
+ }
+
+ // map[statusID]faveID -- we need this to sort statuses by fave ID rather than their own ID
+ statusesFavesMap := map[string]string{}
+
+ in := []string{}
+ for _, f := range faves {
+ statusesFavesMap[f.StatusID] = f.ID
+ in = append(in, f.StatusID)
+ }
+
+ statuses := []*gtsmodel.Status{}
+ err = t.conn.
+ NewSelect().
+ Model(&statuses).
+ Where("id IN (?)", bun.In(in)).
+ Scan(ctx)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, "", "", db.ErrNoEntries
+ }
+ return nil, "", "", err
+ }
+
+ if len(statuses) == 0 {
+ return nil, "", "", db.ErrNoEntries
+ }
+
+ // arrange statuses by fave ID
+ sort.Slice(statuses, func(i int, j int) bool {
+ statusI := statuses[i]
+ statusJ := statuses[j]
+ return statusesFavesMap[statusI.ID] < statusesFavesMap[statusJ.ID]
+ })
+
+ nextMaxID := faves[len(faves)-1].ID
+ prevMinID := faves[0].ID
+ return statuses, nextMaxID, prevMinID, nil
+}
diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go
new file mode 100644
index 000000000..115d18de2
--- /dev/null
+++ b/internal/db/bundb/util.go
@@ -0,0 +1,78 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "context"
+ "strings"
+
+ "database/sql"
+
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/uptrace/bun"
+)
+
+// processErrorResponse parses the given error and returns an appropriate DBError.
+func processErrorResponse(err error) db.Error {
+ switch err {
+ case nil:
+ return nil
+ case sql.ErrNoRows:
+ return db.ErrNoEntries
+ default:
+ if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
+ return db.ErrAlreadyExists
+ }
+ return err
+ }
+}
+
+func exists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
+ count, err := q.Count(ctx)
+
+ exists := count != 0
+
+ err = processErrorResponse(err)
+
+ if err != nil {
+ if err == db.ErrNoEntries {
+ return false, nil
+ }
+ return false, err
+ }
+
+ return exists, nil
+}
+
+func notExists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
+ count, err := q.Count(ctx)
+
+ notExists := count == 0
+
+ err = processErrorResponse(err)
+
+ if err != nil {
+ if err == db.ErrNoEntries {
+ return true, nil
+ }
+ return false, err
+ }
+
+ return notExists, nil
+}