From 2dc9fc1626507bb54417fc4a1920b847cafb27a2 Mon Sep 17 00:00:00 2001
From: tobi <31960611+tsmethurst@users.noreply.github.com>
Date: Wed, 25 Aug 2021 15:34:33 +0200
Subject: Pg to bun (#148)
* start moving to bun
* changing more stuff
* more
* and yet more
* tests passing
* seems stable now
* more big changes
* small fix
* little fixes
---
internal/db/account.go | 26 ++-
internal/db/admin.go | 11 +-
internal/db/basic.go | 31 +--
internal/db/bundb/account.go | 291 ++++++++++++++++++++++++++
internal/db/bundb/account_test.go | 86 ++++++++
internal/db/bundb/admin.go | 272 ++++++++++++++++++++++++
internal/db/bundb/basic.go | 179 ++++++++++++++++
internal/db/bundb/basic_test.go | 68 ++++++
internal/db/bundb/bundb.go | 410 +++++++++++++++++++++++++++++++++++++
internal/db/bundb/bundb_test.go | 47 +++++
internal/db/bundb/domain.go | 81 ++++++++
internal/db/bundb/instance.go | 118 +++++++++++
internal/db/bundb/media.go | 53 +++++
internal/db/bundb/mention.go | 108 ++++++++++
internal/db/bundb/notification.go | 136 ++++++++++++
internal/db/bundb/relationship.go | 328 +++++++++++++++++++++++++++++
internal/db/bundb/session.go | 85 ++++++++
internal/db/bundb/status.go | 375 ++++++++++++++++++++++++++++++++++
internal/db/bundb/status_test.go | 136 ++++++++++++
internal/db/bundb/timeline.go | 199 ++++++++++++++++++
internal/db/bundb/util.go | 78 +++++++
internal/db/db.go | 9 +-
internal/db/domain.go | 13 +-
internal/db/instance.go | 14 +-
internal/db/media.go | 8 +-
internal/db/mention.go | 10 +-
internal/db/notification.go | 10 +-
internal/db/pg/account.go | 256 -----------------------
internal/db/pg/account_test.go | 70 -------
internal/db/pg/admin.go | 235 ---------------------
internal/db/pg/basic.go | 205 -------------------
internal/db/pg/domain.go | 83 --------
internal/db/pg/instance.go | 112 ----------
internal/db/pg/media.go | 53 -----
internal/db/pg/mention.go | 108 ----------
internal/db/pg/notification.go | 135 ------------
internal/db/pg/pg.go | 420 --------------------------------------
internal/db/pg/pg_test.go | 47 -----
internal/db/pg/relationship.go | 276 -------------------------
internal/db/pg/status.go | 318 -----------------------------
internal/db/pg/status_test.go | 134 ------------
internal/db/pg/timeline.go | 210 -------------------
internal/db/pg/util.go | 25 ---
internal/db/relationship.go | 30 +--
internal/db/session.go | 31 +++
internal/db/status.go | 36 ++--
internal/db/timeline.go | 12 +-
47 files changed, 3201 insertions(+), 2777 deletions(-)
create mode 100644 internal/db/bundb/account.go
create mode 100644 internal/db/bundb/account_test.go
create mode 100644 internal/db/bundb/admin.go
create mode 100644 internal/db/bundb/basic.go
create mode 100644 internal/db/bundb/basic_test.go
create mode 100644 internal/db/bundb/bundb.go
create mode 100644 internal/db/bundb/bundb_test.go
create mode 100644 internal/db/bundb/domain.go
create mode 100644 internal/db/bundb/instance.go
create mode 100644 internal/db/bundb/media.go
create mode 100644 internal/db/bundb/mention.go
create mode 100644 internal/db/bundb/notification.go
create mode 100644 internal/db/bundb/relationship.go
create mode 100644 internal/db/bundb/session.go
create mode 100644 internal/db/bundb/status.go
create mode 100644 internal/db/bundb/status_test.go
create mode 100644 internal/db/bundb/timeline.go
create mode 100644 internal/db/bundb/util.go
delete mode 100644 internal/db/pg/account.go
delete mode 100644 internal/db/pg/account_test.go
delete mode 100644 internal/db/pg/admin.go
delete mode 100644 internal/db/pg/basic.go
delete mode 100644 internal/db/pg/domain.go
delete mode 100644 internal/db/pg/instance.go
delete mode 100644 internal/db/pg/media.go
delete mode 100644 internal/db/pg/mention.go
delete mode 100644 internal/db/pg/notification.go
delete mode 100644 internal/db/pg/pg.go
delete mode 100644 internal/db/pg/pg_test.go
delete mode 100644 internal/db/pg/relationship.go
delete mode 100644 internal/db/pg/status.go
delete mode 100644 internal/db/pg/status_test.go
delete mode 100644 internal/db/pg/timeline.go
delete mode 100644 internal/db/pg/util.go
create mode 100644 internal/db/session.go
(limited to 'internal/db')
diff --git a/internal/db/account.go b/internal/db/account.go
index 0e1575f9b..058a89859 100644
--- a/internal/db/account.go
+++ b/internal/db/account.go
@@ -19,6 +19,7 @@
package db
import (
+ "context"
"time"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -27,40 +28,43 @@ import (
// Account contains functions related to account getting/setting/creation.
type Account interface {
// GetAccountByID returns one account with the given ID, or an error if something goes wrong.
- GetAccountByID(id string) (*gtsmodel.Account, Error)
+ GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, Error)
// GetAccountByURI returns one account with the given URI, or an error if something goes wrong.
- GetAccountByURI(uri string) (*gtsmodel.Account, Error)
+ GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
// GetAccountByURL returns one account with the given URL, or an error if something goes wrong.
- GetAccountByURL(uri string) (*gtsmodel.Account, Error)
+ GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error)
+
+ // UpdateAccount updates one account by ID.
+ UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)
// GetLocalAccountByUsername returns an account on this instance by its username.
- GetLocalAccountByUsername(username string) (*gtsmodel.Account, Error)
+ GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, Error)
// GetAccountFaves fetches faves/likes created by the target accountID.
- GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, Error)
+ GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, Error)
// GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID.
- CountAccountStatuses(accountID string) (int, Error)
+ CountAccountStatuses(ctx context.Context, accountID string) (int, Error)
// GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided
// then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
// be very memory intensive so you probably shouldn't do this!
// In case of no entries, a 'no entries' error will be returned
- GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, Error)
+ GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, Error)
- GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error)
+ GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error)
// GetAccountLastPosted simply gets the timestamp of the most recent post by the account.
//
// The returned time will be zero if account has never posted anything.
- GetAccountLastPosted(accountID string) (time.Time, Error)
+ GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, Error)
// SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment.
- SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error
+ SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error
// GetInstanceAccount returns the instance account for the given domain.
// If domain is empty, this instance account will be returned.
- GetInstanceAccount(domain string) (*gtsmodel.Account, Error)
+ GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, Error)
}
diff --git a/internal/db/admin.go b/internal/db/admin.go
index aa2b22f47..24d628e84 100644
--- a/internal/db/admin.go
+++ b/internal/db/admin.go
@@ -19,6 +19,7 @@
package db
import (
+ "context"
"net"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -28,26 +29,26 @@ import (
type Admin interface {
// IsUsernameAvailable checks whether a given username is available on our domain.
// Returns an error if the username is already taken, or something went wrong in the db.
- IsUsernameAvailable(username string) Error
+ IsUsernameAvailable(ctx context.Context, username string) (bool, Error)
// IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain.
// Return an error if:
// A) the email is already associated with an account
// B) we block signups from this email domain
// C) something went wrong in the db
- IsEmailAvailable(email string) Error
+ IsEmailAvailable(ctx context.Context, email string) (bool, Error)
// NewSignup creates a new user in the database with the given parameters.
// By the time this function is called, it should be assumed that all the parameters have passed validation!
- NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, Error)
+ NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, Error)
// CreateInstanceAccount creates an account in the database with the same username as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'.
// This is needed for things like serving files that belong to the instance and not an individual user/account.
- CreateInstanceAccount() Error
+ CreateInstanceAccount(ctx context.Context) Error
// CreateInstanceInstance creates an instance in the database with the same domain as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'.
// This is needed for things like serving instance information through /api/v1/instance
- CreateInstanceInstance() Error
+ CreateInstanceInstance(ctx context.Context) Error
}
diff --git a/internal/db/basic.go b/internal/db/basic.go
index 729920bba..cf65ddc09 100644
--- a/internal/db/basic.go
+++ b/internal/db/basic.go
@@ -24,15 +24,11 @@ import "context"
type Basic interface {
// CreateTable creates a table for the given interface.
// For implementations that don't use tables, this can just return nil.
- CreateTable(i interface{}) Error
+ CreateTable(ctx context.Context, i interface{}) Error
// DropTable drops the table for the given interface.
// For implementations that don't use tables, this can just return nil.
- DropTable(i interface{}) Error
-
- // RegisterTable registers a table for use in many2many relations.
- // For implementations that don't use tables, or many2many relations, this can just return nil.
- RegisterTable(i interface{}) Error
+ DropTable(ctx context.Context, i interface{}) Error
// Stop should stop and close the database connection cleanly, returning an error if this is not possible.
// If the database implementation doesn't need to be stopped, this can just return nil.
@@ -45,43 +41,38 @@ type Basic interface {
// for other implementations (for example, in-memory) it might just be the key of a map.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
- GetByID(id string, i interface{}) Error
+ GetByID(ctx context.Context, id string, i interface{}) Error
// GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the
// name of the key to select from.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
- GetWhere(where []Where, i interface{}) Error
+ GetWhere(ctx context.Context, where []Where, i interface{}) Error
// GetAll will try to get all entries of type i.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
- GetAll(i interface{}) Error
+ GetAll(ctx context.Context, i interface{}) Error
// Put simply stores i. It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- Put(i interface{}) Error
-
- // Upsert stores or updates i based on the given conflict column, as in https://www.postgresqltutorial.com/postgresql-upsert/
- // It is up to the implementation to figure out how to store it, and using what key.
- // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- Upsert(i interface{}, conflictColumn string) Error
+ Put(ctx context.Context, i interface{}) Error
// UpdateByID updates i with id id.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- UpdateByID(id string, i interface{}) Error
+ UpdateByID(ctx context.Context, id string, i interface{}) Error
// UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value.
- UpdateOneByID(id string, key string, value interface{}, i interface{}) Error
+ UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) Error
// UpdateWhere updates column key of interface i with the given value, where the given parameters apply.
- UpdateWhere(where []Where, key string, value interface{}, i interface{}) Error
+ UpdateWhere(ctx context.Context, where []Where, key string, value interface{}, i interface{}) Error
// DeleteByID removes i with id id.
// If i didn't exist anyway, then no error should be returned.
- DeleteByID(id string, i interface{}) Error
+ DeleteByID(ctx context.Context, id string, i interface{}) Error
// DeleteWhere deletes i where key = value
// If i didn't exist anyway, then no error should be returned.
- DeleteWhere(where []Where, i interface{}) Error
+ DeleteWhere(ctx context.Context, where []Where, i interface{}) Error
}
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
new file mode 100644
index 000000000..7ebb79a15
--- /dev/null
+++ b/internal/db/bundb/account.go
@@ -0,0 +1,291 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+type accountDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
+ return a.conn.
+ NewSelect().
+ Model(account).
+ Relation("AvatarMediaAttachment").
+ Relation("HeaderMediaAttachment")
+}
+
+func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
+ account := new(gtsmodel.Account)
+
+ q := a.newAccountQ(account).
+ Where("account.id = ?", id)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return account, err
+}
+
+func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+ account := new(gtsmodel.Account)
+
+ q := a.newAccountQ(account).
+ Where("account.uri = ?", uri)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return account, err
+}
+
+func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+ account := new(gtsmodel.Account)
+
+ q := a.newAccountQ(account).
+ Where("account.url = ?", uri)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return account, err
+}
+
+func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
+ if strings.TrimSpace(account.ID) == "" {
+ return nil, errors.New("account had no ID")
+ }
+
+ account.UpdatedAt = time.Now()
+
+ q := a.conn.
+ NewUpdate().
+ Model(account).
+ WherePK()
+
+ _, err := q.Exec(ctx)
+
+ err = processErrorResponse(err)
+
+ return account, err
+}
+
+func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
+ account := new(gtsmodel.Account)
+
+ q := a.newAccountQ(account)
+
+ if domain == "" {
+ q = q.
+ Where("account.username = ?", domain).
+ Where("account.domain = ?", domain)
+ } else {
+ q = q.
+ Where("account.username = ?", domain).
+ Where("? IS NULL", bun.Ident("domain"))
+ }
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return account, err
+}
+
+func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, db.Error) {
+ status := new(gtsmodel.Status)
+
+ q := a.conn.
+ NewSelect().
+ Model(status).
+ Order("id DESC").
+ Limit(1).
+ Where("account_id = ?", accountID).
+ Column("created_at")
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return status.CreatedAt, err
+}
+
+func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
+ if mediaAttachment.Avatar && mediaAttachment.Header {
+ return errors.New("one media attachment cannot be both header and avatar")
+ }
+
+ var headerOrAVI string
+ if mediaAttachment.Avatar {
+ headerOrAVI = "avatar"
+ } else if mediaAttachment.Header {
+ headerOrAVI = "header"
+ } else {
+ return errors.New("given media attachment was neither a header nor an avatar")
+ }
+
+ // TODO: there are probably more side effects here that need to be handled
+ if _, err := a.conn.
+ NewInsert().
+ Model(mediaAttachment).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ if _, err := a.conn.
+ NewUpdate().
+ Model(>smodel.Account{}).
+ Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).
+ Where("id = ?", accountID).
+ Exec(ctx); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, db.Error) {
+ account := new(gtsmodel.Account)
+
+ q := a.newAccountQ(account).
+ Where("username = ?", username).
+ Where("? IS NULL", bun.Ident("domain"))
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return account, err
+}
+
+func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) {
+ faves := new([]*gtsmodel.StatusFave)
+
+ if err := a.conn.
+ NewSelect().
+ Model(faves).
+ Where("account_id = ?", accountID).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+ return *faves, nil
+}
+
+func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) {
+ return a.conn.
+ NewSelect().
+ Model(>smodel.Status{}).
+ Where("account_id = ?", accountID).
+ Count(ctx)
+}
+
+func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) {
+ a.log.Debugf("getting statuses for account %s", accountID)
+ statuses := []*gtsmodel.Status{}
+
+ q := a.conn.
+ NewSelect().
+ Model(&statuses).
+ Order("id DESC")
+
+ if accountID != "" {
+ q = q.Where("account_id = ?", accountID)
+ }
+
+ if limit != 0 {
+ q = q.Limit(limit)
+ }
+
+ if excludeReplies {
+ q = q.Where("? IS NULL", bun.Ident("in_reply_to_id"))
+ }
+
+ if pinnedOnly {
+ q = q.Where("pinned = ?", true)
+ }
+
+ if maxID != "" {
+ q = q.Where("id < ?", maxID)
+ }
+
+ if mediaOnly {
+ q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery {
+ return q.
+ WhereOr("? IS NOT NULL", bun.Ident("attachments")).
+ WhereOr("attachments != '{}'")
+ })
+ }
+
+ if err := q.Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ if len(statuses) == 0 {
+ return nil, db.ErrNoEntries
+ }
+
+ a.log.Debugf("returning statuses for account %s", accountID)
+ return statuses, nil
+}
+
+func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) {
+ blocks := []*gtsmodel.Block{}
+
+ fq := a.conn.
+ NewSelect().
+ Model(&blocks).
+ Where("block.account_id = ?", accountID).
+ Relation("TargetAccount").
+ Order("block.id DESC")
+
+ if maxID != "" {
+ fq = fq.Where("block.id < ?", maxID)
+ }
+
+ if sinceID != "" {
+ fq = fq.Where("block.id > ?", sinceID)
+ }
+
+ if limit > 0 {
+ fq = fq.Limit(limit)
+ }
+
+ err := fq.Scan(ctx)
+ if err != nil {
+ return nil, "", "", err
+ }
+
+ if len(blocks) == 0 {
+ return nil, "", "", db.ErrNoEntries
+ }
+
+ accounts := []*gtsmodel.Account{}
+ for _, b := range blocks {
+ accounts = append(accounts, b.TargetAccount)
+ }
+
+ nextMaxID := blocks[len(blocks)-1].ID
+ prevMinID := blocks[0].ID
+ return accounts, nextMaxID, prevMinID, nil
+}
diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go
new file mode 100644
index 000000000..7174b781d
--- /dev/null
+++ b/internal/db/bundb/account_test.go
@@ -0,0 +1,86 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb_test
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type AccountTestSuite struct {
+ BunDBStandardTestSuite
+}
+
+func (suite *AccountTestSuite) SetupSuite() {
+ suite.testTokens = testrig.NewTestTokens()
+ suite.testClients = testrig.NewTestClients()
+ suite.testApplications = testrig.NewTestApplications()
+ suite.testUsers = testrig.NewTestUsers()
+ suite.testAccounts = testrig.NewTestAccounts()
+ suite.testAttachments = testrig.NewTestAttachments()
+ suite.testStatuses = testrig.NewTestStatuses()
+ suite.testTags = testrig.NewTestTags()
+ suite.testMentions = testrig.NewTestMentions()
+}
+
+func (suite *AccountTestSuite) SetupTest() {
+ suite.config = testrig.NewTestConfig()
+ suite.db = testrig.NewTestDB()
+ suite.log = testrig.NewTestLog()
+
+ testrig.StandardDBSetup(suite.db, suite.testAccounts)
+}
+
+func (suite *AccountTestSuite) TearDownTest() {
+ testrig.StandardDBTeardown(suite.db)
+}
+
+func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
+ account, err := suite.db.GetAccountByID(context.Background(), suite.testAccounts["local_account_1"].ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.NotNil(account)
+ suite.NotNil(account.AvatarMediaAttachment)
+ suite.NotEmpty(account.AvatarMediaAttachment.URL)
+ suite.NotNil(account.HeaderMediaAttachment)
+ suite.NotEmpty(account.HeaderMediaAttachment.URL)
+}
+
+func (suite *AccountTestSuite) TestUpdateAccount() {
+ testAccount := suite.testAccounts["local_account_1"]
+
+ testAccount.DisplayName = "new display name!"
+
+ _, err := suite.db.UpdateAccount(context.Background(), testAccount)
+ suite.NoError(err)
+
+ updated, err := suite.db.GetAccountByID(context.Background(), testAccount.ID)
+ suite.NoError(err)
+ suite.Equal("new display name!", updated.DisplayName)
+ suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second)
+}
+
+func TestAccountTestSuite(t *testing.T) {
+ suite.Run(t, new(AccountTestSuite))
+}
diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go
new file mode 100644
index 000000000..67a1e8a0d
--- /dev/null
+++ b/internal/db/bundb/admin.go
@@ -0,0 +1,272 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "database/sql"
+ "fmt"
+ "net"
+ "net/mail"
+ "strings"
+ "time"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/id"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
+ "github.com/uptrace/bun"
+ "golang.org/x/crypto/bcrypt"
+)
+
+type adminDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
+ q := a.conn.
+ NewSelect().
+ Model(>smodel.Account{}).
+ Where("username = ?", username).
+ Where("domain = ?", nil)
+
+ return notExists(ctx, q)
+}
+
+func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.Error) {
+ // parse the domain from the email
+ m, err := mail.ParseAddress(email)
+ if err != nil {
+ return false, fmt.Errorf("error parsing email address %s: %s", email, err)
+ }
+ domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
+
+ // check if the email domain is blocked
+ if err := a.conn.
+ NewSelect().
+ Model(>smodel.EmailDomainBlock{}).
+ Where("domain = ?", domain).
+ Scan(ctx); err == nil {
+ // fail because we found something
+ return false, fmt.Errorf("email domain %s is blocked", domain)
+ } else if err != sql.ErrNoRows {
+ return false, processErrorResponse(err)
+ }
+
+ // check if this email is associated with a user already
+ q := a.conn.
+ NewSelect().
+ Model(>smodel.User{}).
+ Where("email = ?", email).
+ WhereOr("unconfirmed_email = ?", email)
+
+ return notExists(ctx, q)
+}
+
+func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ a.log.Errorf("error creating new rsa key: %s", err)
+ return nil, err
+ }
+
+ // if something went wrong while creating a user, we might already have an account, so check here first...
+ acct := >smodel.Account{}
+ err = a.conn.NewSelect().
+ Model(acct).
+ Where("username = ?", username).
+ Where("? IS NULL", bun.Ident("domain")).
+ Scan(ctx)
+ if err != nil {
+ // we just don't have an account yet create one
+ newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
+ newAccountID, err := id.NewRandomULID()
+ if err != nil {
+ return nil, err
+ }
+
+ acct = >smodel.Account{
+ ID: newAccountID,
+ Username: username,
+ DisplayName: username,
+ Reason: reason,
+ URL: newAccountURIs.UserURL,
+ PrivateKey: key,
+ PublicKey: &key.PublicKey,
+ PublicKeyURI: newAccountURIs.PublicKeyURI,
+ ActorType: gtsmodel.ActivityStreamsPerson,
+ URI: newAccountURIs.UserURI,
+ InboxURI: newAccountURIs.InboxURI,
+ OutboxURI: newAccountURIs.OutboxURI,
+ FollowersURI: newAccountURIs.FollowersURI,
+ FollowingURI: newAccountURIs.FollowingURI,
+ FeaturedCollectionURI: newAccountURIs.CollectionURI,
+ }
+ if _, err = a.conn.
+ NewInsert().
+ Model(acct).
+ Exec(ctx); err != nil {
+ return nil, err
+ }
+ }
+
+ pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
+ if err != nil {
+ return nil, fmt.Errorf("error hashing password: %s", err)
+ }
+
+ newUserID, err := id.NewRandomULID()
+ if err != nil {
+ return nil, err
+ }
+
+ u := >smodel.User{
+ ID: newUserID,
+ AccountID: acct.ID,
+ EncryptedPassword: string(pw),
+ SignUpIP: signUpIP.To4(),
+ Locale: locale,
+ UnconfirmedEmail: email,
+ CreatedByApplicationID: appID,
+ Approved: !requireApproval, // if we don't require moderator approval, just pre-approve the user
+ }
+
+ if emailVerified {
+ u.ConfirmedAt = time.Now()
+ u.Email = email
+ }
+
+ if admin {
+ u.Admin = true
+ u.Moderator = true
+ }
+
+ if _, err = a.conn.
+ NewInsert().
+ Model(u).
+ Exec(ctx); err != nil {
+ return nil, err
+ }
+
+ return u, nil
+}
+
+func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
+ username := a.config.Host
+
+ // check if instance account already exists
+ existsQ := a.conn.
+ NewSelect().
+ Model(>smodel.Account{}).
+ Where("username = ?", username).
+ Where("? IS NULL", bun.Ident("domain"))
+ count, err := existsQ.Count(ctx)
+ if err != nil && count == 1 {
+ a.log.Infof("instance account %s already exists", username)
+ return nil
+ } else if err != sql.ErrNoRows {
+ return processErrorResponse(err)
+ }
+
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ a.log.Errorf("error creating new rsa key: %s", err)
+ return err
+ }
+
+ aID, err := id.NewRandomULID()
+ if err != nil {
+ return err
+ }
+
+ newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
+ acct := >smodel.Account{
+ ID: aID,
+ Username: a.config.Host,
+ DisplayName: username,
+ URL: newAccountURIs.UserURL,
+ PrivateKey: key,
+ PublicKey: &key.PublicKey,
+ PublicKeyURI: newAccountURIs.PublicKeyURI,
+ ActorType: gtsmodel.ActivityStreamsPerson,
+ URI: newAccountURIs.UserURI,
+ InboxURI: newAccountURIs.InboxURI,
+ OutboxURI: newAccountURIs.OutboxURI,
+ FollowersURI: newAccountURIs.FollowersURI,
+ FollowingURI: newAccountURIs.FollowingURI,
+ FeaturedCollectionURI: newAccountURIs.CollectionURI,
+ }
+
+ insertQ := a.conn.
+ NewInsert().
+ Model(acct)
+
+ if _, err := insertQ.Exec(ctx); err != nil {
+ return err
+ }
+
+ a.log.Infof("instance account %s CREATED with id %s", username, acct.ID)
+ return nil
+}
+
+func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
+ domain := a.config.Host
+
+ // check if instance entry already exists
+ existsQ := a.conn.
+ NewSelect().
+ Model(>smodel.Instance{}).
+ Where("domain = ?", domain)
+
+ count, err := existsQ.Count(ctx)
+ if err != nil && count == 1 {
+ a.log.Infof("instance instance %s already exists", domain)
+ return nil
+ } else if err != sql.ErrNoRows {
+ return processErrorResponse(err)
+ }
+
+ iID, err := id.NewRandomULID()
+ if err != nil {
+ return err
+ }
+
+ i := >smodel.Instance{
+ ID: iID,
+ Domain: domain,
+ Title: domain,
+ URI: fmt.Sprintf("%s://%s", a.config.Protocol, a.config.Host),
+ }
+
+ insertQ := a.conn.
+ NewInsert().
+ Model(i)
+
+ if _, err := insertQ.Exec(ctx); err != nil {
+ return err
+ }
+ a.log.Infof("created instance instance %s with id %s", domain, i.ID)
+ return nil
+}
diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go
new file mode 100644
index 000000000..983b6b810
--- /dev/null
+++ b/internal/db/bundb/basic.go
@@ -0,0 +1,179 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb
+
+import (
+ "context"
+ "errors"
+ "strings"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/uptrace/bun"
+)
+
+type basicDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (b *basicDB) Put(ctx context.Context, i interface{}) db.Error {
+ _, err := b.conn.NewInsert().Model(i).Exec(ctx)
+ if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
+ return db.ErrAlreadyExists
+ }
+ return err
+}
+
+func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Error {
+ q := b.conn.
+ NewSelect().
+ Model(i).
+ Where("id = ?", id)
+
+ return processErrorResponse(q.Scan(ctx))
+}
+
+func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
+ if len(where) == 0 {
+ return errors.New("no queries provided")
+ }
+
+ q := b.conn.NewSelect().Model(i)
+ for _, w := range where {
+
+ if w.Value == nil {
+ q = q.Where("? IS NULL", bun.Ident(w.Key))
+ } else {
+ if w.CaseInsensitive {
+ q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value)
+ } else {
+ q = q.Where("? = ?", bun.Safe(w.Key), w.Value)
+ }
+ }
+ }
+
+ return processErrorResponse(q.Scan(ctx))
+}
+
+func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error {
+ q := b.conn.
+ NewSelect().
+ Model(i)
+
+ return processErrorResponse(q.Scan(ctx))
+}
+
+func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error {
+ q := b.conn.
+ NewDelete().
+ Model(i).
+ Where("id = ?", id)
+
+ _, err := q.Exec(ctx)
+
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
+ if len(where) == 0 {
+ return errors.New("no queries provided")
+ }
+
+ q := b.conn.
+ NewDelete().
+ Model(i)
+
+ for _, w := range where {
+ q = q.Where("? = ?", bun.Safe(w.Key), w.Value)
+ }
+
+ _, err := q.Exec(ctx)
+
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.Error {
+ q := b.conn.
+ NewUpdate().
+ Model(i).
+ WherePK()
+
+ _, err := q.Exec(ctx)
+
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) db.Error {
+ q := b.conn.NewUpdate().
+ Model(i).
+ Set("? = ?", bun.Safe(key), value).
+ WherePK()
+
+ _, err := q.Exec(ctx)
+
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error {
+ q := b.conn.NewUpdate().Model(i)
+
+ for _, w := range where {
+ if w.Value == nil {
+ q = q.Where("? IS NULL", bun.Ident(w.Key))
+ } else {
+ if w.CaseInsensitive {
+ q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value)
+ } else {
+ q = q.Where("? = ?", bun.Safe(w.Key), w.Value)
+ }
+ }
+ }
+
+ q = q.Set("? = ?", bun.Safe(key), value)
+
+ _, err := q.Exec(ctx)
+
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
+ _, err := b.conn.NewCreateTable().Model(i).IfNotExists().Exec(ctx)
+ return err
+}
+
+func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error {
+ _, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx)
+ return processErrorResponse(err)
+}
+
+func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
+ return b.conn.Ping()
+}
+
+func (b *basicDB) Stop(ctx context.Context) db.Error {
+ b.log.Info("closing db connection")
+ if err := b.conn.Close(); err != nil {
+ // only cancel if there's a problem closing the db
+ return err
+ }
+ return nil
+}
diff --git a/internal/db/bundb/basic_test.go b/internal/db/bundb/basic_test.go
new file mode 100644
index 000000000..9189618c9
--- /dev/null
+++ b/internal/db/bundb/basic_test.go
@@ -0,0 +1,68 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type BasicTestSuite struct {
+ BunDBStandardTestSuite
+}
+
+func (suite *BasicTestSuite) SetupSuite() {
+ suite.testTokens = testrig.NewTestTokens()
+ suite.testClients = testrig.NewTestClients()
+ suite.testApplications = testrig.NewTestApplications()
+ suite.testUsers = testrig.NewTestUsers()
+ suite.testAccounts = testrig.NewTestAccounts()
+ suite.testAttachments = testrig.NewTestAttachments()
+ suite.testStatuses = testrig.NewTestStatuses()
+ suite.testTags = testrig.NewTestTags()
+ suite.testMentions = testrig.NewTestMentions()
+}
+
+func (suite *BasicTestSuite) SetupTest() {
+ suite.config = testrig.NewTestConfig()
+ suite.db = testrig.NewTestDB()
+ suite.log = testrig.NewTestLog()
+
+ testrig.StandardDBSetup(suite.db, suite.testAccounts)
+}
+
+func (suite *BasicTestSuite) TearDownTest() {
+ testrig.StandardDBTeardown(suite.db)
+}
+
+func (suite *BasicTestSuite) TestGetAccountByID() {
+ testAccount := suite.testAccounts["local_account_1"]
+
+ a := >smodel.Account{}
+ err := suite.db.GetByID(context.Background(), testAccount.ID, a)
+ suite.NoError(err)
+}
+
+func TestBasicTestSuite(t *testing.T) {
+ suite.Run(t, new(BasicTestSuite))
+}
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
new file mode 100644
index 000000000..49ed09cbd
--- /dev/null
+++ b/internal/db/bundb/bundb.go
@@ -0,0 +1,410 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb
+
+import (
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "database/sql"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/jackc/pgx/v4"
+ "github.com/jackc/pgx/v4/stdlib"
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/id"
+ "github.com/uptrace/bun"
+ "github.com/uptrace/bun/dialect/pgdialect"
+)
+
+const (
+ dbTypePostgres = "postgres"
+ dbTypeSqlite = "sqlite"
+)
+
+var registerTables []interface{} = []interface{}{
+ >smodel.StatusToEmoji{},
+ >smodel.StatusToTag{},
+}
+
+// bunDBService satisfies the DB interface
+type bunDBService struct {
+ db.Account
+ db.Admin
+ db.Basic
+ db.Domain
+ db.Instance
+ db.Media
+ db.Mention
+ db.Notification
+ db.Relationship
+ db.Session
+ db.Status
+ db.Timeline
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface.
+// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
+func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
+ var sqldb *sql.DB
+ var conn *bun.DB
+
+ // depending on the database type we're trying to create, we need to use a different driver...
+ switch strings.ToLower(c.DBConfig.Type) {
+ case dbTypePostgres:
+ // POSTGRES
+ opts, err := deriveBunDBPGOptions(c)
+ if err != nil {
+ return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
+ }
+ sqldb = stdlib.OpenDB(*opts)
+ conn = bun.NewDB(sqldb, pgdialect.New())
+ case dbTypeSqlite:
+ // SQLITE
+ // TODO: https://bun.uptrace.dev/guide/drivers.html#sqlite
+ default:
+ return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type))
+ }
+
+ // actually *begin* the connection so that we can tell if the db is there and listening
+ if err := conn.Ping(); err != nil {
+ return nil, fmt.Errorf("db connection error: %s", err)
+ }
+ log.Info("connected to database")
+
+ for _, t := range registerTables {
+ // https://bun.uptrace.dev/orm/many-to-many-relation/
+ conn.RegisterModel(t)
+ }
+
+ ps := &bunDBService{
+ Account: &accountDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Admin: &adminDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Basic: &basicDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Domain: &domainDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Instance: &instanceDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Media: &mediaDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Mention: &mentionDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Notification: ¬ificationDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Relationship: &relationshipDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Session: &sessionDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Status: &statusDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ Timeline: &timelineDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
+ config: c,
+ conn: conn,
+ log: log,
+ }
+
+ // we can confidently return this useable service now
+ return ps, nil
+}
+
+/*
+ HANDY STUFF
+*/
+
+// deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options
+// with sensible defaults, or an error if it's not satisfied by the provided config.
+func deriveBunDBPGOptions(c *config.Config) (*pgx.ConnConfig, error) {
+ if strings.ToUpper(c.DBConfig.Type) != db.DBTypePostgres {
+ return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, c.DBConfig.Type)
+ }
+
+ // validate port
+ if c.DBConfig.Port == 0 {
+ return nil, errors.New("no port set")
+ }
+
+ // validate address
+ if c.DBConfig.Address == "" {
+ return nil, errors.New("no address set")
+ }
+
+ // validate username
+ if c.DBConfig.User == "" {
+ return nil, errors.New("no user set")
+ }
+
+ // validate that there's a password
+ if c.DBConfig.Password == "" {
+ return nil, errors.New("no password set")
+ }
+
+ // validate database
+ if c.DBConfig.Database == "" {
+ return nil, errors.New("no database set")
+ }
+
+ var tlsConfig *tls.Config
+ switch c.DBConfig.TLSMode {
+ case config.DBTLSModeDisable, config.DBTLSModeUnset:
+ break // nothing to do
+ case config.DBTLSModeEnable:
+ tlsConfig = &tls.Config{
+ InsecureSkipVerify: true,
+ }
+ case config.DBTLSModeRequire:
+ tlsConfig = &tls.Config{
+ InsecureSkipVerify: false,
+ ServerName: c.DBConfig.Address,
+ }
+ }
+
+ if tlsConfig != nil && c.DBConfig.TLSCACert != "" {
+ // load the system cert pool first -- we'll append the given CA cert to this
+ certPool, err := x509.SystemCertPool()
+ if err != nil {
+ return nil, fmt.Errorf("error fetching system CA cert pool: %s", err)
+ }
+
+ // open the file itself and make sure there's something in it
+ caCertBytes, err := os.ReadFile(c.DBConfig.TLSCACert)
+ if err != nil {
+ return nil, fmt.Errorf("error opening CA certificate at %s: %s", c.DBConfig.TLSCACert, err)
+ }
+ if len(caCertBytes) == 0 {
+ return nil, fmt.Errorf("ca cert at %s was empty", c.DBConfig.TLSCACert)
+ }
+
+ // make sure we have a PEM block
+ caPem, _ := pem.Decode(caCertBytes)
+ if caPem == nil {
+ return nil, fmt.Errorf("could not parse cert at %s into PEM", c.DBConfig.TLSCACert)
+ }
+
+ // parse the PEM block into the certificate
+ caCert, err := x509.ParseCertificate(caPem.Bytes)
+ if err != nil {
+ return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", c.DBConfig.TLSCACert, err)
+ }
+
+ // we're happy, add it to the existing pool and then use this pool in our tls config
+ certPool.AddCert(caCert)
+ tlsConfig.RootCAs = certPool
+ }
+
+ cfg, _ := pgx.ParseConfig("")
+ cfg.Host = c.DBConfig.Address
+ cfg.Port = uint16(c.DBConfig.Port)
+ cfg.User = c.DBConfig.User
+ cfg.Password = c.DBConfig.Password
+ cfg.TLSConfig = tlsConfig
+ cfg.Database = c.DBConfig.Database
+ cfg.PreferSimpleProtocol = true
+
+ return cfg, nil
+}
+
+/*
+ CONVERSION FUNCTIONS
+*/
+
+// TODO: move these to the type converter, it's bananas that they're here and not there
+
+func (ps *bunDBService) MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) {
+ ogAccount := >smodel.Account{}
+ if err := ps.conn.NewSelect().Model(ogAccount).Where("id = ?", originAccountID).Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ menchies := []*gtsmodel.Mention{}
+ for _, a := range targetAccounts {
+ // A mentioned account looks like "@test@example.org" or just "@test" for a local account
+ // -- we can guarantee this from the regex that targetAccounts should have been derived from.
+ // But we still need to do a bit of fiddling to get what we need here -- the username and domain (if given).
+
+ // 1. trim off the first @
+ t := strings.TrimPrefix(a, "@")
+
+ // 2. split the username and domain
+ s := strings.Split(t, "@")
+
+ // 3. if it's length 1 it's a local account, length 2 means remote, anything else means something is wrong
+ var local bool
+ switch len(s) {
+ case 1:
+ local = true
+ case 2:
+ local = false
+ default:
+ return nil, fmt.Errorf("mentioned account format '%s' was not valid", a)
+ }
+
+ var username, domain string
+ username = s[0]
+ if !local {
+ domain = s[1]
+ }
+
+ // 4. check we now have a proper username and domain
+ if username == "" || (!local && domain == "") {
+ return nil, fmt.Errorf("username or domain for '%s' was nil", a)
+ }
+
+ // okay we're good now, we can start pulling accounts out of the database
+ mentionedAccount := >smodel.Account{}
+ var err error
+
+ // match username + account, case insensitive
+ if local {
+ // local user -- should have a null domain
+ err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("? IS NULL", bun.Ident("domain")).Scan(ctx)
+ } else {
+ // remote user -- should have domain defined
+ err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("LOWER(?) = LOWER(?)", bun.Ident("domain"), domain).Scan(ctx)
+ }
+
+ if err != nil {
+ if err == sql.ErrNoRows {
+ // no result found for this username/domain so just don't include it as a mencho and carry on about our business
+ ps.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain)
+ continue
+ }
+ // a serious error has happened so bail
+ return nil, fmt.Errorf("error getting account with username '%s' and domain '%s': %s", username, domain, err)
+ }
+
+ // id, createdAt and updatedAt will be populated by the db, so we have everything we need!
+ menchies = append(menchies, >smodel.Mention{
+ StatusID: statusID,
+ OriginAccountID: ogAccount.ID,
+ OriginAccountURI: ogAccount.URI,
+ TargetAccountID: mentionedAccount.ID,
+ NameString: a,
+ TargetAccountURI: mentionedAccount.URI,
+ TargetAccountURL: mentionedAccount.URL,
+ OriginAccount: mentionedAccount,
+ })
+ }
+ return menchies, nil
+}
+
+func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) {
+ newTags := []*gtsmodel.Tag{}
+ for _, t := range tags {
+ tag := >smodel.Tag{}
+ // we can use selectorinsert here to create the new tag if it doesn't exist already
+ // inserted will be true if this is a new tag we just created
+ if err := ps.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil {
+ if err == sql.ErrNoRows {
+ // tag doesn't exist yet so populate it
+ newID, err := id.NewRandomULID()
+ if err != nil {
+ return nil, err
+ }
+ tag.ID = newID
+ tag.URL = fmt.Sprintf("%s://%s/tags/%s", ps.config.Protocol, ps.config.Host, t)
+ tag.Name = t
+ tag.FirstSeenFromAccountID = originAccountID
+ tag.CreatedAt = time.Now()
+ tag.UpdatedAt = time.Now()
+ tag.Useable = true
+ tag.Listable = true
+ } else {
+ return nil, fmt.Errorf("error getting tag with name %s: %s", t, err)
+ }
+ }
+
+ // bail already if the tag isn't useable
+ if !tag.Useable {
+ continue
+ }
+ tag.LastStatusAt = time.Now()
+ newTags = append(newTags, tag)
+ }
+ return newTags, nil
+}
+
+func (ps *bunDBService) EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) {
+ newEmojis := []*gtsmodel.Emoji{}
+ for _, e := range emojis {
+ emoji := >smodel.Emoji{}
+ err := ps.conn.NewSelect().Model(emoji).Where("shortcode = ?", e).Where("visible_in_picker = true").Where("disabled = false").Scan(ctx)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ // no result found for this username/domain so just don't include it as an emoji and carry on about our business
+ ps.log.Debugf("no emoji found with shortcode %s, skipping it", e)
+ continue
+ }
+ // a serious error has happened so bail
+ return nil, fmt.Errorf("error getting emoji with shortcode %s: %s", e, err)
+ }
+ newEmojis = append(newEmojis, emoji)
+ }
+ return newEmojis, nil
+}
diff --git a/internal/db/bundb/bundb_test.go b/internal/db/bundb/bundb_test.go
new file mode 100644
index 000000000..b789375af
--- /dev/null
+++ b/internal/db/bundb/bundb_test.go
@@ -0,0 +1,47 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb_test
+
+import (
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/oauth"
+)
+
+type BunDBStandardTestSuite struct {
+ // standard suite interfaces
+ suite.Suite
+ config *config.Config
+ db db.DB
+ log *logrus.Logger
+
+ // standard suite models
+ testTokens map[string]*oauth.Token
+ testClients map[string]*oauth.Client
+ testApplications map[string]*gtsmodel.Application
+ testUsers map[string]*gtsmodel.User
+ testAccounts map[string]*gtsmodel.Account
+ testAttachments map[string]*gtsmodel.MediaAttachment
+ testStatuses map[string]*gtsmodel.Status
+ testTags map[string]*gtsmodel.Tag
+ testMentions map[string]*gtsmodel.Mention
+}
diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go
new file mode 100644
index 000000000..6aa2b8ffe
--- /dev/null
+++ b/internal/db/bundb/domain.go
@@ -0,0 +1,81 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb
+
+import (
+ "context"
+ "net/url"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
+ "github.com/uptrace/bun"
+)
+
+type domainDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) {
+ if domain == "" {
+ return false, nil
+ }
+
+ q := d.conn.
+ NewSelect().
+ Model(>smodel.DomainBlock{}).
+ Where("LOWER(domain) = LOWER(?)", domain).
+ Limit(1)
+
+ return exists(ctx, q)
+}
+
+func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) {
+ // filter out any doubles
+ uniqueDomains := util.UniqueStrings(domains)
+
+ for _, domain := range uniqueDomains {
+ if blocked, err := d.IsDomainBlocked(ctx, domain); err != nil {
+ return false, err
+ } else if blocked {
+ return blocked, nil
+ }
+ }
+
+ // no blocks found
+ return false, nil
+}
+
+func (d *domainDB) IsURIBlocked(ctx context.Context, uri *url.URL) (bool, db.Error) {
+ domain := uri.Hostname()
+ return d.IsDomainBlocked(ctx, domain)
+}
+
+func (d *domainDB) AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, db.Error) {
+ domains := []string{}
+ for _, uri := range uris {
+ domains = append(domains, uri.Hostname())
+ }
+
+ return d.AreDomainsBlocked(ctx, domains)
+}
diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go
new file mode 100644
index 000000000..f9364346e
--- /dev/null
+++ b/internal/db/bundb/instance.go
@@ -0,0 +1,118 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb
+
+import (
+ "context"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+type instanceDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {
+ q := i.conn.
+ NewSelect().
+ Model(&[]*gtsmodel.Account{})
+
+ if domain == i.config.Host {
+ // if the domain is *this* domain, just count where the domain field is null
+ q = q.Where("? IS NULL", bun.Ident("domain"))
+ } else {
+ q = q.Where("domain = ?", domain)
+ }
+
+ // don't count the instance account or suspended users
+ q = q.
+ Where("username != ?", domain).
+ Where("? IS NULL", bun.Ident("suspended_at"))
+
+ count, err := q.Count(ctx)
+
+ return count, processErrorResponse(err)
+}
+
+func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) {
+ q := i.conn.
+ NewSelect().
+ Model(&[]*gtsmodel.Status{})
+
+ if domain == i.config.Host {
+ // if the domain is *this* domain, just count where local is true
+ q = q.Where("local = ?", true)
+ } else {
+ // join on the domain of the account
+ q = q.Join("JOIN accounts AS account ON account.id = status.account_id").
+ Where("account.domain = ?", domain)
+ }
+
+ count, err := q.Count(ctx)
+
+ return count, processErrorResponse(err)
+}
+
+func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) {
+ q := i.conn.
+ NewSelect().
+ Model(&[]*gtsmodel.Instance{})
+
+ if domain == i.config.Host {
+ // if the domain is *this* domain, just count other instances it knows about
+ // exclude domains that are blocked
+ q = q.Where("domain != ?", domain).Where("? IS NULL", bun.Ident("suspended_at"))
+ } else {
+ // TODO: implement federated domain counting properly for remote domains
+ return 0, nil
+ }
+
+ count, err := q.Count(ctx)
+
+ return count, processErrorResponse(err)
+}
+
+func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
+ i.log.Debug("GetAccountsForInstance")
+
+ accounts := []*gtsmodel.Account{}
+
+ q := i.conn.NewSelect().
+ Model(&accounts).
+ Where("domain = ?", domain).
+ Order("id DESC")
+
+ if maxID != "" {
+ q = q.Where("id < ?", maxID)
+ }
+
+ if limit > 0 {
+ q = q.Limit(limit)
+ }
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return accounts, err
+}
diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go
new file mode 100644
index 000000000..04e55ca62
--- /dev/null
+++ b/internal/db/bundb/media.go
@@ -0,0 +1,53 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb
+
+import (
+ "context"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+type mediaDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (m *mediaDB) newMediaQ(i interface{}) *bun.SelectQuery {
+ return m.conn.
+ NewSelect().
+ Model(i).
+ Relation("Account")
+}
+
+func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, db.Error) {
+ attachment := >smodel.MediaAttachment{}
+
+ q := m.newMediaQ(attachment).
+ Where("media_attachment.id = ?", id)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return attachment, err
+}
diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go
new file mode 100644
index 000000000..a444f9b5f
--- /dev/null
+++ b/internal/db/bundb/mention.go
@@ -0,0 +1,108 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb
+
+import (
+ "context"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+type mentionDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+ cache cache.Cache
+}
+
+func (m *mentionDB) cacheMention(id string, mention *gtsmodel.Mention) {
+ if m.cache == nil {
+ m.cache = cache.New()
+ }
+
+ if err := m.cache.Store(id, mention); err != nil {
+ m.log.Panicf("mentionDB: error storing in cache: %s", err)
+ }
+}
+
+func (m *mentionDB) mentionCached(id string) (*gtsmodel.Mention, bool) {
+ if m.cache == nil {
+ m.cache = cache.New()
+ return nil, false
+ }
+
+ mI, err := m.cache.Fetch(id)
+ if err != nil || mI == nil {
+ return nil, false
+ }
+
+ mention, ok := mI.(*gtsmodel.Mention)
+ if !ok {
+ m.log.Panicf("mentionDB: cached interface with key %s was not a mention", id)
+ }
+
+ return mention, true
+}
+
+func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {
+ return m.conn.
+ NewSelect().
+ Model(i).
+ Relation("Status").
+ Relation("OriginAccount").
+ Relation("TargetAccount")
+}
+
+func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
+ if mention, cached := m.mentionCached(id); cached {
+ return mention, nil
+ }
+
+ mention := >smodel.Mention{}
+
+ q := m.newMentionQ(mention).
+ Where("mention.id = ?", id)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ if err == nil && mention != nil {
+ m.cacheMention(id, mention)
+ }
+
+ return mention, err
+}
+
+func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) {
+ mentions := []*gtsmodel.Mention{}
+
+ for _, i := range ids {
+ mention, err := m.GetMention(ctx, i)
+ if err != nil {
+ return nil, processErrorResponse(err)
+ }
+ mentions = append(mentions, mention)
+ }
+
+ return mentions, nil
+}
diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go
new file mode 100644
index 000000000..1c30837ec
--- /dev/null
+++ b/internal/db/bundb/notification.go
@@ -0,0 +1,136 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb
+
+import (
+ "context"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+type notificationDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+ cache cache.Cache
+}
+
+func (n *notificationDB) cacheNotification(id string, notification *gtsmodel.Notification) {
+ if n.cache == nil {
+ n.cache = cache.New()
+ }
+
+ if err := n.cache.Store(id, notification); err != nil {
+ n.log.Panicf("notificationDB: error storing in cache: %s", err)
+ }
+}
+
+func (n *notificationDB) notificationCached(id string) (*gtsmodel.Notification, bool) {
+ if n.cache == nil {
+ n.cache = cache.New()
+ return nil, false
+ }
+
+ nI, err := n.cache.Fetch(id)
+ if err != nil || nI == nil {
+ return nil, false
+ }
+
+ notification, ok := nI.(*gtsmodel.Notification)
+ if !ok {
+ n.log.Panicf("notificationDB: cached interface with key %s was not a notification", id)
+ }
+
+ return notification, true
+}
+
+func (n *notificationDB) newNotificationQ(i interface{}) *bun.SelectQuery {
+ return n.conn.
+ NewSelect().
+ Model(i).
+ Relation("OriginAccount").
+ Relation("TargetAccount").
+ Relation("Status")
+}
+
+func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
+ if notification, cached := n.notificationCached(id); cached {
+ return notification, nil
+ }
+
+ notification := >smodel.Notification{}
+
+ q := n.newNotificationQ(notification).
+ Where("notification.id = ?", id)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ if err == nil && notification != nil {
+ n.cacheNotification(id, notification)
+ }
+
+ return notification, err
+}
+
+func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
+ // begin by selecting just the IDs
+ notifIDs := []*gtsmodel.Notification{}
+ q := n.conn.
+ NewSelect().
+ Model(¬ifIDs).
+ Column("id").
+ Where("target_account_id = ?", accountID).
+ Order("id DESC")
+
+ if maxID != "" {
+ q = q.Where("id < ?", maxID)
+ }
+
+ if sinceID != "" {
+ q = q.Where("id > ?", sinceID)
+ }
+
+ if limit != 0 {
+ q = q.Limit(limit)
+ }
+
+ err := processErrorResponse(q.Scan(ctx))
+ if err != nil {
+ return nil, err
+ }
+
+ // now we have the IDs, select the notifs one by one
+ // reason for this is that for each notif, we can instead get it from our cache if it's cached
+ notifications := []*gtsmodel.Notification{}
+ for _, notifID := range notifIDs {
+ notif, err := n.GetNotification(ctx, notifID.ID)
+ errP := processErrorResponse(err)
+ if errP != nil {
+ return nil, errP
+ }
+ notifications = append(notifications, notif)
+ }
+
+ return notifications, nil
+}
diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go
new file mode 100644
index 000000000..ccc604baf
--- /dev/null
+++ b/internal/db/bundb/relationship.go
@@ -0,0 +1,328 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+type relationshipDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *bun.SelectQuery {
+ return r.conn.
+ NewSelect().
+ Model(block).
+ Relation("Account").
+ Relation("TargetAccount")
+}
+
+func (r *relationshipDB) newFollowQ(follow interface{}) *bun.SelectQuery {
+ return r.conn.
+ NewSelect().
+ Model(follow).
+ Relation("Account").
+ Relation("TargetAccount")
+}
+
+func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
+ q := r.conn.
+ NewSelect().
+ Model(>smodel.Block{}).
+ Where("account_id = ?", account1).
+ Where("target_account_id = ?", account2).
+ Limit(1)
+
+ if eitherDirection {
+ q = q.
+ WhereOr("target_account_id = ?", account1).
+ Where("account_id = ?", account2)
+ }
+
+ return exists(ctx, q)
+}
+
+func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
+ block := >smodel.Block{}
+
+ q := r.newBlockQ(block).
+ Where("block.account_id = ?", account1).
+ Where("block.target_account_id = ?", account2)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return block, err
+}
+
+func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
+ rel := >smodel.Relationship{
+ ID: targetAccount,
+ }
+
+ // check if the requesting account follows the target account
+ follow := >smodel.Follow{}
+ if err := r.conn.
+ NewSelect().
+ Model(follow).
+ Where("account_id = ?", requestingAccount).
+ Where("target_account_id = ?", targetAccount).
+ Limit(1).
+ Scan(ctx); err != nil {
+ if err != sql.ErrNoRows {
+ // a proper error
+ return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
+ }
+ // no follow exists so these are all false
+ rel.Following = false
+ rel.ShowingReblogs = false
+ rel.Notifying = false
+ } else {
+ // follow exists so we can fill these fields out...
+ rel.Following = true
+ rel.ShowingReblogs = follow.ShowReblogs
+ rel.Notifying = follow.Notify
+ }
+
+ // check if the target account follows the requesting account
+ count, err := r.conn.
+ NewSelect().
+ Model(>smodel.Follow{}).
+ Where("account_id = ?", targetAccount).
+ Where("target_account_id = ?", requestingAccount).
+ Limit(1).
+ Count(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
+ }
+ rel.FollowedBy = count > 0
+
+ // check if the requesting account blocks the target account
+ count, err = r.conn.NewSelect().
+ Model(>smodel.Block{}).
+ Where("account_id = ?", requestingAccount).
+ Where("target_account_id = ?", targetAccount).
+ Limit(1).
+ Count(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
+ }
+ rel.Blocking = count > 0
+
+ // check if the target account blocks the requesting account
+ count, err = r.conn.
+ NewSelect().
+ Model(>smodel.Block{}).
+ Where("account_id = ?", targetAccount).
+ Where("target_account_id = ?", requestingAccount).
+ Limit(1).
+ Count(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
+ }
+ rel.BlockedBy = count > 0
+
+ // check if there's a pending following request from requesting account to target account
+ count, err = r.conn.
+ NewSelect().
+ Model(>smodel.FollowRequest{}).
+ Where("account_id = ?", requestingAccount).
+ Where("target_account_id = ?", targetAccount).
+ Limit(1).
+ Count(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
+ }
+ rel.Requested = count > 0
+
+ return rel, nil
+}
+
+func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
+ if sourceAccount == nil || targetAccount == nil {
+ return false, nil
+ }
+
+ q := r.conn.
+ NewSelect().
+ Model(>smodel.Follow{}).
+ Where("account_id = ?", sourceAccount.ID).
+ Where("target_account_id = ?", targetAccount.ID).
+ Limit(1)
+
+ return exists(ctx, q)
+}
+
+func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
+ if sourceAccount == nil || targetAccount == nil {
+ return false, nil
+ }
+
+ q := r.conn.
+ NewSelect().
+ Model(>smodel.FollowRequest{}).
+ Where("account_id = ?", sourceAccount.ID).
+ Where("target_account_id = ?", targetAccount.ID)
+
+ return exists(ctx, q)
+}
+
+func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
+ if account1 == nil || account2 == nil {
+ return false, nil
+ }
+
+ // make sure account 1 follows account 2
+ f1, err := r.IsFollowing(ctx, account1, account2)
+ if err != nil {
+ return false, processErrorResponse(err)
+ }
+
+ // make sure account 2 follows account 1
+ f2, err := r.IsFollowing(ctx, account2, account1)
+ if err != nil {
+ return false, processErrorResponse(err)
+ }
+
+ return f1 && f2, nil
+}
+
+func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
+ // make sure the original follow request exists
+ fr := >smodel.FollowRequest{}
+ if err := r.conn.
+ NewSelect().
+ Model(fr).
+ Where("account_id = ?", originAccountID).
+ Where("target_account_id = ?", targetAccountID).
+ Scan(ctx); err != nil {
+ return nil, processErrorResponse(err)
+ }
+
+ // create a new follow to 'replace' the request with
+ follow := >smodel.Follow{
+ ID: fr.ID,
+ AccountID: originAccountID,
+ TargetAccountID: targetAccountID,
+ URI: fr.URI,
+ }
+
+ // if the follow already exists, just update the URI -- we don't need to do anything else
+ if _, err := r.conn.
+ NewInsert().
+ Model(follow).
+ On("CONFLICT CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).
+ Exec(ctx); err != nil {
+ return nil, processErrorResponse(err)
+ }
+
+ // now remove the follow request
+ if _, err := r.conn.
+ NewDelete().
+ Model(>smodel.FollowRequest{}).
+ Where("account_id = ?", originAccountID).
+ Where("target_account_id = ?", targetAccountID).
+ Exec(ctx); err != nil {
+ return nil, processErrorResponse(err)
+ }
+
+ return follow, nil
+}
+
+func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) {
+ followRequests := []*gtsmodel.FollowRequest{}
+
+ q := r.newFollowQ(&followRequests).
+ Where("target_account_id = ?", accountID)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return followRequests, err
+}
+
+func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, db.Error) {
+ follows := []*gtsmodel.Follow{}
+
+ q := r.newFollowQ(&follows).
+ Where("account_id = ?", accountID)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ return follows, err
+}
+
+func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
+ return r.conn.
+ NewSelect().
+ Model(&[]*gtsmodel.Follow{}).
+ Where("account_id = ?", accountID).
+ Count(ctx)
+}
+
+func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
+
+ follows := []*gtsmodel.Follow{}
+
+ q := r.conn.
+ NewSelect().
+ Model(&follows)
+
+ if localOnly {
+ // for local accounts let's get where domain is null OR where domain is an empty string, just to be safe
+ whereGroup := func(q *bun.SelectQuery) *bun.SelectQuery {
+ q = q.
+ WhereOr("? IS NULL", bun.Ident("a.domain")).
+ WhereOr("a.domain = ?", "")
+ return q
+ }
+
+ q = q.ColumnExpr("follow.*").
+ Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)").
+ Where("follow.target_account_id = ?", accountID).
+ WhereGroup(" AND ", whereGroup)
+ } else {
+ q = q.Where("target_account_id = ?", accountID)
+ }
+
+ if err := q.Scan(ctx); err != nil {
+ if err == sql.ErrNoRows {
+ return follows, nil
+ }
+ return nil, processErrorResponse(err)
+ }
+ return follows, nil
+}
+
+func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
+ return r.conn.
+ NewSelect().
+ Model(&[]*gtsmodel.Follow{}).
+ Where("target_account_id = ?", accountID).
+ Count(ctx)
+}
diff --git a/internal/db/bundb/session.go b/internal/db/bundb/session.go
new file mode 100644
index 000000000..87e20673d
--- /dev/null
+++ b/internal/db/bundb/session.go
@@ -0,0 +1,85 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb
+
+import (
+ "context"
+ "crypto/rand"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/id"
+ "github.com/uptrace/bun"
+)
+
+type sessionDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
+ rs := new(gtsmodel.RouterSession)
+
+ q := s.conn.
+ NewSelect().
+ Model(rs).
+ Limit(1)
+
+ _, err := q.Exec(ctx)
+
+ err = processErrorResponse(err)
+
+ return rs, err
+}
+
+func (s *sessionDB) CreateSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
+ auth := make([]byte, 32)
+ crypt := make([]byte, 32)
+
+ if _, err := rand.Read(auth); err != nil {
+ return nil, err
+ }
+ if _, err := rand.Read(crypt); err != nil {
+ return nil, err
+ }
+
+ rid, err := id.NewULID()
+ if err != nil {
+ return nil, err
+ }
+
+ rs := >smodel.RouterSession{
+ ID: rid,
+ Auth: auth,
+ Crypt: crypt,
+ }
+
+ q := s.conn.
+ NewInsert().
+ Model(rs)
+
+ _, err = q.Exec(ctx)
+
+ err = processErrorResponse(err)
+
+ return rs, err
+}
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go
new file mode 100644
index 000000000..da8d8ca41
--- /dev/null
+++ b/internal/db/bundb/status.go
@@ -0,0 +1,375 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb
+
+import (
+ "container/list"
+ "context"
+ "errors"
+ "time"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+type statusDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+ cache cache.Cache
+}
+
+func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) {
+ if s.cache == nil {
+ s.cache = cache.New()
+ }
+
+ if err := s.cache.Store(id, status); err != nil {
+ s.log.Panicf("statusDB: error storing in cache: %s", err)
+ }
+}
+
+func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) {
+ if s.cache == nil {
+ s.cache = cache.New()
+ return nil, false
+ }
+
+ sI, err := s.cache.Fetch(id)
+ if err != nil || sI == nil {
+ return nil, false
+ }
+
+ status, ok := sI.(*gtsmodel.Status)
+ if !ok {
+ s.log.Panicf("statusDB: cached interface with key %s was not a status", id)
+ }
+
+ return status, true
+}
+
+func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
+ return s.conn.
+ NewSelect().
+ Model(status).
+ Relation("Attachments").
+ Relation("Tags").
+ Relation("Mentions").
+ Relation("Emojis").
+ Relation("Account").
+ Relation("InReplyToAccount").
+ Relation("BoostOfAccount").
+ Relation("CreatedWithApplication")
+}
+
+func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status {
+ if status.InReplyToID != "" && status.InReplyTo == nil {
+ if inReplyTo, cached := s.statusCached(status.InReplyToID); cached {
+ status.InReplyTo = inReplyTo
+ } else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil {
+ status.InReplyTo = inReplyTo
+ }
+ }
+
+ if status.BoostOfID != "" && status.BoostOf == nil {
+ if boostOf, cached := s.statusCached(status.BoostOfID); cached {
+ status.BoostOf = boostOf
+ } else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil {
+ status.BoostOf = boostOf
+ }
+ }
+
+ return status
+}
+
+func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
+ return s.conn.
+ NewSelect().
+ Model(faves).
+ Relation("Account").
+ Relation("TargetAccount").
+ Relation("Status")
+}
+
+func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) {
+ if status, cached := s.statusCached(id); cached {
+ return status, nil
+ }
+
+ status := new(gtsmodel.Status)
+
+ q := s.newStatusQ(status).
+ Where("status.id = ?", id)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ if err != nil {
+ return nil, err
+ }
+
+ if status != nil {
+ s.cacheStatus(id, status)
+ }
+
+ return s.getAttachedStatuses(ctx, status), err
+}
+
+func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
+ if status, cached := s.statusCached(uri); cached {
+ return status, nil
+ }
+
+ status := >smodel.Status{}
+
+ q := s.newStatusQ(status).
+ Where("LOWER(status.uri) = LOWER(?)", uri)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ if err != nil {
+ return nil, err
+ }
+
+ if status != nil {
+ s.cacheStatus(uri, status)
+ }
+
+ return s.getAttachedStatuses(ctx, status), err
+}
+
+func (s *statusDB) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
+ if status, cached := s.statusCached(uri); cached {
+ return status, nil
+ }
+
+ status := >smodel.Status{}
+
+ q := s.newStatusQ(status).
+ Where("LOWER(status.url) = LOWER(?)", uri)
+
+ err := processErrorResponse(q.Scan(ctx))
+
+ if err != nil {
+ return nil, err
+ }
+
+ if status != nil {
+ s.cacheStatus(uri, status)
+ }
+
+ return s.getAttachedStatuses(ctx, status), err
+}
+
+func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
+ transaction := func(ctx context.Context, tx bun.Tx) error {
+ // create links between this status and any emojis it uses
+ for _, i := range status.EmojiIDs {
+ if _, err := tx.NewInsert().Model(>smodel.StatusToEmoji{
+ StatusID: status.ID,
+ EmojiID: i,
+ }).Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ // create links between this status and any tags it uses
+ for _, i := range status.TagIDs {
+ if _, err := tx.NewInsert().Model(>smodel.StatusToTag{
+ StatusID: status.ID,
+ TagID: i,
+ }).Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ // change the status ID of the media attachments to the new status
+ for _, a := range status.Attachments {
+ a.StatusID = status.ID
+ a.UpdatedAt = time.Now()
+ if _, err := s.conn.NewUpdate().Model(a).
+ Where("id = ?", a.ID).
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ _, err := tx.NewInsert().Model(status).Exec(ctx)
+ return err
+ }
+
+ return processErrorResponse(s.conn.RunInTx(ctx, nil, transaction))
+}
+
+func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
+ parents := []*gtsmodel.Status{}
+ s.statusParent(ctx, status, &parents, onlyDirect)
+
+ return parents, nil
+}
+
+func (s *statusDB) statusParent(ctx context.Context, status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
+ if status.InReplyToID == "" {
+ return
+ }
+
+ parentStatus, err := s.GetStatusByID(ctx, status.InReplyToID)
+ if err == nil {
+ *foundStatuses = append(*foundStatuses, parentStatus)
+ }
+
+ if onlyDirect {
+ return
+ }
+
+ s.statusParent(ctx, parentStatus, foundStatuses, false)
+}
+
+func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) {
+ foundStatuses := &list.List{}
+ foundStatuses.PushFront(status)
+ s.statusChildren(ctx, status, foundStatuses, onlyDirect, minID)
+
+ children := []*gtsmodel.Status{}
+ for e := foundStatuses.Front(); e != nil; e = e.Next() {
+ entry, ok := e.Value.(*gtsmodel.Status)
+ if !ok {
+ panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
+ }
+
+ // only append children, not the overall parent status
+ if entry.ID != status.ID {
+ children = append(children, entry)
+ }
+ }
+
+ return children, nil
+}
+
+func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
+ immediateChildren := []*gtsmodel.Status{}
+
+ q := s.conn.
+ NewSelect().
+ Model(&immediateChildren).
+ Where("in_reply_to_id = ?", status.ID)
+ if minID != "" {
+ q = q.Where("status.id > ?", minID)
+ }
+
+ if err := q.Scan(ctx); err != nil {
+ return
+ }
+
+ for _, child := range immediateChildren {
+ insertLoop:
+ for e := foundStatuses.Front(); e != nil; e = e.Next() {
+ entry, ok := e.Value.(*gtsmodel.Status)
+ if !ok {
+ panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
+ }
+
+ if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {
+ foundStatuses.InsertAfter(child, e)
+ break insertLoop
+ }
+ }
+
+ // only do one loop if we only want direct children
+ if onlyDirect {
+ return
+ }
+ s.statusChildren(ctx, child, foundStatuses, false, minID)
+ }
+}
+
+func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
+ return s.conn.NewSelect().Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count(ctx)
+}
+
+func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
+ return s.conn.NewSelect().Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count(ctx)
+}
+
+func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
+ return s.conn.NewSelect().Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count(ctx)
+}
+
+func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
+ q := s.conn.
+ NewSelect().
+ Model(>smodel.StatusFave{}).
+ Where("status_id = ?", status.ID).
+ Where("account_id = ?", accountID)
+
+ return exists(ctx, q)
+}
+
+func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
+ q := s.conn.
+ NewSelect().
+ Model(>smodel.Status{}).
+ Where("boost_of_id = ?", status.ID).
+ Where("account_id = ?", accountID)
+
+ return exists(ctx, q)
+}
+
+func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
+ q := s.conn.
+ NewSelect().
+ Model(>smodel.StatusMute{}).
+ Where("status_id = ?", status.ID).
+ Where("account_id = ?", accountID)
+
+ return exists(ctx, q)
+}
+
+func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
+ q := s.conn.
+ NewSelect().
+ Model(>smodel.StatusBookmark{}).
+ Where("status_id = ?", status.ID).
+ Where("account_id = ?", accountID)
+
+ return exists(ctx, q)
+}
+
+func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
+ faves := []*gtsmodel.StatusFave{}
+
+ q := s.newFaveQ(&faves).
+ Where("status_id = ?", status.ID)
+
+ err := processErrorResponse(q.Scan(ctx))
+ return faves, err
+}
+
+func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
+ reblogs := []*gtsmodel.Status{}
+
+ q := s.newStatusQ(&reblogs).
+ Where("boost_of_id = ?", status.ID)
+
+ err := processErrorResponse(q.Scan(ctx))
+ return reblogs, err
+}
diff --git a/internal/db/bundb/status_test.go b/internal/db/bundb/status_test.go
new file mode 100644
index 000000000..513000577
--- /dev/null
+++ b/internal/db/bundb/status_test.go
@@ -0,0 +1,136 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb_test
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type StatusTestSuite struct {
+ BunDBStandardTestSuite
+}
+
+func (suite *StatusTestSuite) SetupSuite() {
+ suite.testTokens = testrig.NewTestTokens()
+ suite.testClients = testrig.NewTestClients()
+ suite.testApplications = testrig.NewTestApplications()
+ suite.testUsers = testrig.NewTestUsers()
+ suite.testAccounts = testrig.NewTestAccounts()
+ suite.testAttachments = testrig.NewTestAttachments()
+ suite.testStatuses = testrig.NewTestStatuses()
+ suite.testTags = testrig.NewTestTags()
+ suite.testMentions = testrig.NewTestMentions()
+}
+
+func (suite *StatusTestSuite) SetupTest() {
+ suite.config = testrig.NewTestConfig()
+ suite.db = testrig.NewTestDB()
+ suite.log = testrig.NewTestLog()
+
+ testrig.StandardDBSetup(suite.db, suite.testAccounts)
+}
+
+func (suite *StatusTestSuite) TearDownTest() {
+ testrig.StandardDBTeardown(suite.db)
+}
+
+func (suite *StatusTestSuite) TestGetStatusByID() {
+ status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_1_status_1"].ID)
+ if err != nil {
+ fmt.Println(err.Error())
+ suite.FailNow(err.Error())
+ }
+ suite.NotNil(status)
+ suite.NotNil(status.Account)
+ suite.NotNil(status.CreatedWithApplication)
+ suite.Nil(status.BoostOf)
+ suite.Nil(status.BoostOfAccount)
+ suite.Nil(status.InReplyTo)
+ suite.Nil(status.InReplyToAccount)
+}
+
+func (suite *StatusTestSuite) TestGetStatusByURI() {
+ status, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.NotNil(status)
+ suite.NotNil(status.Account)
+ suite.NotNil(status.CreatedWithApplication)
+ suite.Nil(status.BoostOf)
+ suite.Nil(status.BoostOfAccount)
+ suite.Nil(status.InReplyTo)
+ suite.Nil(status.InReplyToAccount)
+}
+
+func (suite *StatusTestSuite) TestGetStatusWithExtras() {
+ status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["admin_account_status_1"].ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.NotNil(status)
+ suite.NotNil(status.Account)
+ suite.NotNil(status.CreatedWithApplication)
+ suite.NotEmpty(status.Tags)
+ suite.NotEmpty(status.Attachments)
+ suite.NotEmpty(status.Emojis)
+}
+
+func (suite *StatusTestSuite) TestGetStatusWithMention() {
+ status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_2_status_5"].ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.NotNil(status)
+ suite.NotNil(status.Account)
+ suite.NotNil(status.CreatedWithApplication)
+ suite.NotEmpty(status.Mentions)
+ suite.NotEmpty(status.MentionIDs)
+ suite.NotNil(status.InReplyTo)
+ suite.NotNil(status.InReplyToAccount)
+}
+
+func (suite *StatusTestSuite) TestGetStatusTwice() {
+ before1 := time.Now()
+ _, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
+ suite.NoError(err)
+ after1 := time.Now()
+ duration1 := after1.Sub(before1)
+ fmt.Println(duration1.Milliseconds())
+
+ before2 := time.Now()
+ _, err = suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
+ suite.NoError(err)
+ after2 := time.Now()
+ duration2 := after2.Sub(before2)
+ fmt.Println(duration2.Milliseconds())
+
+ // second retrieval should be several orders faster since it will be cached now
+ suite.Less(duration2, duration1)
+}
+
+func TestStatusTestSuite(t *testing.T) {
+ suite.Run(t, new(StatusTestSuite))
+}
diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go
new file mode 100644
index 000000000..b62ad4c50
--- /dev/null
+++ b/internal/db/bundb/timeline.go
@@ -0,0 +1,199 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb
+
+import (
+ "context"
+ "database/sql"
+ "sort"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+type timelineDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
+ statuses := []*gtsmodel.Status{}
+ q := t.conn.
+ NewSelect().
+ Model(&statuses)
+
+ q = q.ColumnExpr("status.*").
+ // Find out who accountID follows.
+ Join("LEFT JOIN follows AS f ON f.target_account_id = status.account_id").
+ // Sort by highest ID (newest) to lowest ID (oldest)
+ Order("status.id DESC")
+
+ if maxID != "" {
+ // return only statuses LOWER (ie., older) than maxID
+ q = q.Where("status.id < ?", maxID)
+ }
+
+ if sinceID != "" {
+ // return only statuses HIGHER (ie., newer) than sinceID
+ q = q.Where("status.id > ?", sinceID)
+ }
+
+ if minID != "" {
+ // return only statuses HIGHER (ie., newer) than minID
+ q = q.Where("status.id > ?", minID)
+ }
+
+ if local {
+ // return only statuses posted by local account havers
+ q = q.Where("status.local = ?", local)
+ }
+
+ if limit > 0 {
+ // limit amount of statuses returned
+ q = q.Limit(limit)
+ }
+
+ // Use a WhereGroup here to specify that we want EITHER statuses posted by accounts that accountID follows,
+ // OR statuses posted by accountID itself (since a user should be able to see their own statuses).
+ //
+ // This is equivalent to something like WHERE ... AND (... OR ...)
+ // See: https://bun.uptrace.dev/guide/queries.html#select
+ whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {
+ return q.
+ WhereOr("f.account_id = ?", accountID).
+ WhereOr("status.account_id = ?", accountID)
+ }
+
+ q = q.WhereGroup(" AND ", whereGroup)
+
+ return statuses, processErrorResponse(q.Scan(ctx))
+}
+
+func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
+ statuses := []*gtsmodel.Status{}
+
+ q := t.conn.
+ NewSelect().
+ Model(&statuses).
+ Where("visibility = ?", gtsmodel.VisibilityPublic).
+ Where("? IS NULL", bun.Ident("in_reply_to_id")).
+ Where("? IS NULL", bun.Ident("in_reply_to_uri")).
+ Where("? IS NULL", bun.Ident("boost_of_id")).
+ Order("status.id DESC")
+
+ if maxID != "" {
+ q = q.Where("status.id < ?", maxID)
+ }
+
+ if sinceID != "" {
+ q = q.Where("status.id > ?", sinceID)
+ }
+
+ if minID != "" {
+ q = q.Where("status.id > ?", minID)
+ }
+
+ if local {
+ q = q.Where("status.local = ?", local)
+ }
+
+ if limit > 0 {
+ q = q.Limit(limit)
+ }
+
+ return statuses, processErrorResponse(q.Scan(ctx))
+}
+
+// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
+// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
+func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
+
+ faves := []*gtsmodel.StatusFave{}
+
+ fq := t.conn.
+ NewSelect().
+ Model(&faves).
+ Where("account_id = ?", accountID).
+ Order("id DESC")
+
+ if maxID != "" {
+ fq = fq.Where("id < ?", maxID)
+ }
+
+ if minID != "" {
+ fq = fq.Where("id > ?", minID)
+ }
+
+ if limit > 0 {
+ fq = fq.Limit(limit)
+ }
+
+ err := fq.Scan(ctx)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, "", "", db.ErrNoEntries
+ }
+ return nil, "", "", err
+ }
+
+ if len(faves) == 0 {
+ return nil, "", "", db.ErrNoEntries
+ }
+
+ // map[statusID]faveID -- we need this to sort statuses by fave ID rather than their own ID
+ statusesFavesMap := map[string]string{}
+
+ in := []string{}
+ for _, f := range faves {
+ statusesFavesMap[f.StatusID] = f.ID
+ in = append(in, f.StatusID)
+ }
+
+ statuses := []*gtsmodel.Status{}
+ err = t.conn.
+ NewSelect().
+ Model(&statuses).
+ Where("id IN (?)", bun.In(in)).
+ Scan(ctx)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, "", "", db.ErrNoEntries
+ }
+ return nil, "", "", err
+ }
+
+ if len(statuses) == 0 {
+ return nil, "", "", db.ErrNoEntries
+ }
+
+ // arrange statuses by fave ID
+ sort.Slice(statuses, func(i int, j int) bool {
+ statusI := statuses[i]
+ statusJ := statuses[j]
+ return statusesFavesMap[statusI.ID] < statusesFavesMap[statusJ.ID]
+ })
+
+ nextMaxID := faves[len(faves)-1].ID
+ prevMinID := faves[0].ID
+ return statuses, nextMaxID, prevMinID, nil
+}
diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go
new file mode 100644
index 000000000..115d18de2
--- /dev/null
+++ b/internal/db/bundb/util.go
@@ -0,0 +1,78 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb
+
+import (
+ "context"
+ "strings"
+
+ "database/sql"
+
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/uptrace/bun"
+)
+
+// processErrorResponse parses the given error and returns an appropriate DBError.
+func processErrorResponse(err error) db.Error {
+ switch err {
+ case nil:
+ return nil
+ case sql.ErrNoRows:
+ return db.ErrNoEntries
+ default:
+ if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
+ return db.ErrAlreadyExists
+ }
+ return err
+ }
+}
+
+func exists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
+ count, err := q.Count(ctx)
+
+ exists := count != 0
+
+ err = processErrorResponse(err)
+
+ if err != nil {
+ if err == db.ErrNoEntries {
+ return false, nil
+ }
+ return false, err
+ }
+
+ return exists, nil
+}
+
+func notExists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
+ count, err := q.Count(ctx)
+
+ notExists := count == 0
+
+ err = processErrorResponse(err)
+
+ if err != nil {
+ if err == db.ErrNoEntries {
+ return true, nil
+ }
+ return false, err
+ }
+
+ return notExists, nil
+}
diff --git a/internal/db/db.go b/internal/db/db.go
index d6ac883e4..ec94fcfe7 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -19,6 +19,8 @@
package db
import (
+ "context"
+
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
@@ -38,6 +40,7 @@ type DB interface {
Mention
Notification
Relationship
+ Session
Status
Timeline
@@ -52,7 +55,7 @@ type DB interface {
//
// Note: this func doesn't/shouldn't do any manipulation of the accounts in the DB, it's just for checking
// if they exist in the db and conveniently returning them if they do.
- MentionStringsToMentions(targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error)
+ MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error)
// TagStringsToTags takes a slice of deduplicated, lowercase tags in the form "somehashtag", which have been
// used in a status. It takes the id of the account that wrote the status, and the id of the status itself, and then
@@ -61,7 +64,7 @@ type DB interface {
//
// Note: this func doesn't/shouldn't do any manipulation of the tags in the DB, it's just for checking
// if they exist in the db already, and conveniently returning them, or creating new tag structs.
- TagStringsToTags(tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error)
+ TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error)
// EmojiStringsToEmojis takes a slice of deduplicated, lowercase emojis in the form ":emojiname:", which have been
// used in a status. It takes the id of the account that wrote the status, and the id of the status itself, and then
@@ -69,5 +72,5 @@ type DB interface {
//
// Note: this func doesn't/shouldn't do any manipulation of the emoji in the DB, it's just for checking
// if they exist in the db and conveniently returning them if they do.
- EmojiStringsToEmojis(emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error)
+ EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error)
}
diff --git a/internal/db/domain.go b/internal/db/domain.go
index a6583c80c..df50a6770 100644
--- a/internal/db/domain.go
+++ b/internal/db/domain.go
@@ -18,19 +18,22 @@
package db
-import "net/url"
+import (
+ "context"
+ "net/url"
+)
// Domain contains DB functions related to domains and domain blocks.
type Domain interface {
// IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`).
- IsDomainBlocked(domain string) (bool, Error)
+ IsDomainBlocked(ctx context.Context, domain string) (bool, Error)
// AreDomainsBlocked checks if an instance-level domain block exists for any of the given domains strings, and returns true if even one is found.
- AreDomainsBlocked(domains []string) (bool, Error)
+ AreDomainsBlocked(ctx context.Context, domains []string) (bool, Error)
// IsURIBlocked checks if an instance-level domain block exists for the `host` in the given URI (eg., `https://example.org/users/whatever`).
- IsURIBlocked(uri *url.URL) (bool, Error)
+ IsURIBlocked(ctx context.Context, uri *url.URL) (bool, Error)
// AreURIsBlocked checks if an instance-level domain block exists for any `host` in the given URI slice, and returns true if even one is found.
- AreURIsBlocked(uris []*url.URL) (bool, Error)
+ AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, Error)
}
diff --git a/internal/db/instance.go b/internal/db/instance.go
index 1f7c83e4f..dcd978a81 100644
--- a/internal/db/instance.go
+++ b/internal/db/instance.go
@@ -18,19 +18,23 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Instance contains functions for instance-level actions (counting instance users etc.).
type Instance interface {
// CountInstanceUsers returns the number of known accounts registered with the given domain.
- CountInstanceUsers(domain string) (int, Error)
+ CountInstanceUsers(ctx context.Context, domain string) (int, Error)
// CountInstanceStatuses returns the number of known statuses posted from the given domain.
- CountInstanceStatuses(domain string) (int, Error)
+ CountInstanceStatuses(ctx context.Context, domain string) (int, Error)
// CountInstanceDomains returns the number of known instances known that the given domain federates with.
- CountInstanceDomains(domain string) (int, Error)
+ CountInstanceDomains(ctx context.Context, domain string) (int, Error)
// GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID.
- GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, Error)
+ GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, Error)
}
diff --git a/internal/db/media.go b/internal/db/media.go
index db4db3411..b779dd276 100644
--- a/internal/db/media.go
+++ b/internal/db/media.go
@@ -18,10 +18,14 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Media contains functions related to creating/getting/removing media attachments.
type Media interface {
// GetAttachmentByID gets a single attachment by its ID
- GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, Error)
+ GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, Error)
}
diff --git a/internal/db/mention.go b/internal/db/mention.go
index cb1c56dc1..b9b45546a 100644
--- a/internal/db/mention.go
+++ b/internal/db/mention.go
@@ -18,13 +18,17 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Mention contains functions for getting/creating mentions in the database.
type Mention interface {
// GetMention gets a single mention by ID
- GetMention(id string) (*gtsmodel.Mention, Error)
+ GetMention(ctx context.Context, id string) (*gtsmodel.Mention, Error)
// GetMentions gets multiple mentions.
- GetMentions(ids []string) ([]*gtsmodel.Mention, Error)
+ GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, Error)
}
diff --git a/internal/db/notification.go b/internal/db/notification.go
index 326f0f149..09c17f031 100644
--- a/internal/db/notification.go
+++ b/internal/db/notification.go
@@ -18,14 +18,18 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Notification contains functions for creating and getting notifications.
type Notification interface {
// GetNotifications returns a slice of notifications that pertain to the given accountID.
//
// Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest).
- GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
+ GetNotifications(ctx context.Context, accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
// GetNotification returns one notification according to its id.
- GetNotification(id string) (*gtsmodel.Notification, Error)
+ GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, Error)
}
diff --git a/internal/db/pg/account.go b/internal/db/pg/account.go
deleted file mode 100644
index 3889c6601..000000000
--- a/internal/db/pg/account.go
+++ /dev/null
@@ -1,256 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-*/
-
-package pg
-
-import (
- "context"
- "errors"
- "fmt"
- "time"
-
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
- "github.com/sirupsen/logrus"
- "github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
-)
-
-type accountDB struct {
- config *config.Config
- conn *pg.DB
- log *logrus.Logger
- cancel context.CancelFunc
-}
-
-func (a *accountDB) newAccountQ(account *gtsmodel.Account) *orm.Query {
- return a.conn.Model(account).
- Relation("AvatarMediaAttachment").
- Relation("HeaderMediaAttachment")
-}
-
-func (a *accountDB) GetAccountByID(id string) (*gtsmodel.Account, db.Error) {
- account := >smodel.Account{}
-
- q := a.newAccountQ(account).
- Where("account.id = ?", id)
-
- err := processErrorResponse(q.Select())
-
- return account, err
-}
-
-func (a *accountDB) GetAccountByURI(uri string) (*gtsmodel.Account, db.Error) {
- account := >smodel.Account{}
-
- q := a.newAccountQ(account).
- Where("account.uri = ?", uri)
-
- err := processErrorResponse(q.Select())
-
- return account, err
-}
-
-func (a *accountDB) GetAccountByURL(uri string) (*gtsmodel.Account, db.Error) {
- account := >smodel.Account{}
-
- q := a.newAccountQ(account).
- Where("account.url = ?", uri)
-
- err := processErrorResponse(q.Select())
-
- return account, err
-}
-
-func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.Error) {
- account := >smodel.Account{}
-
- q := a.newAccountQ(account)
-
- if domain == "" {
- q = q.
- Where("account.username = ?", domain).
- Where("account.domain = ?", domain)
- } else {
- q = q.
- Where("account.username = ?", domain).
- Where("? IS NULL", pg.Ident("domain"))
- }
-
- err := processErrorResponse(q.Select())
-
- return account, err
-}
-
-func (a *accountDB) GetAccountLastPosted(accountID string) (time.Time, db.Error) {
- status := >smodel.Status{}
-
- q := a.conn.Model(status).
- Order("id DESC").
- Limit(1).
- Where("account_id = ?", accountID).
- Column("created_at")
-
- err := processErrorResponse(q.Select())
-
- return status.CreatedAt, err
-}
-
-func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
- if mediaAttachment.Avatar && mediaAttachment.Header {
- return errors.New("one media attachment cannot be both header and avatar")
- }
-
- var headerOrAVI string
- if mediaAttachment.Avatar {
- headerOrAVI = "avatar"
- } else if mediaAttachment.Header {
- headerOrAVI = "header"
- } else {
- return errors.New("given media attachment was neither a header nor an avatar")
- }
-
- // TODO: there are probably more side effects here that need to be handled
- if _, err := a.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil {
- return err
- }
-
- if _, err := a.conn.Model(>smodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil {
- return err
- }
- return nil
-}
-
-func (a *accountDB) GetLocalAccountByUsername(username string) (*gtsmodel.Account, db.Error) {
- account := >smodel.Account{}
-
- q := a.newAccountQ(account).
- Where("username = ?", username).
- Where("? IS NULL", pg.Ident("domain"))
-
- err := processErrorResponse(q.Select())
-
- return account, err
-}
-
-func (a *accountDB) GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, db.Error) {
- faves := []*gtsmodel.StatusFave{}
-
- if err := a.conn.Model(&faves).
- Where("account_id = ?", accountID).
- Select(); err != nil {
- if err == pg.ErrNoRows {
- return faves, nil
- }
- return nil, err
- }
- return faves, nil
-}
-
-func (a *accountDB) CountAccountStatuses(accountID string) (int, db.Error) {
- return a.conn.Model(>smodel.Status{}).Where("account_id = ?", accountID).Count()
-}
-
-func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) {
- a.log.Debugf("getting statuses for account %s", accountID)
- statuses := []*gtsmodel.Status{}
-
- q := a.conn.Model(&statuses).Order("id DESC")
- if accountID != "" {
- q = q.Where("account_id = ?", accountID)
- }
-
- if limit != 0 {
- q = q.Limit(limit)
- }
-
- if excludeReplies {
- q = q.Where("? IS NULL", pg.Ident("in_reply_to_id"))
- }
-
- if pinnedOnly {
- q = q.Where("pinned = ?", true)
- }
-
- if mediaOnly {
- q = q.WhereGroup(func(q *pg.Query) (*pg.Query, error) {
- return q.Where("? IS NOT NULL", pg.Ident("attachments")).Where("attachments != '{}'"), nil
- })
- }
-
- if maxID != "" {
- q = q.Where("id < ?", maxID)
- }
-
- if err := q.Select(); err != nil {
- if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries
- }
- return nil, err
- }
-
- if len(statuses) == 0 {
- return nil, db.ErrNoEntries
- }
-
- a.log.Debugf("returning statuses for account %s", accountID)
- return statuses, nil
-}
-
-func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) {
- blocks := []*gtsmodel.Block{}
-
- fq := a.conn.Model(&blocks).
- Where("block.account_id = ?", accountID).
- Relation("TargetAccount").
- Order("block.id DESC")
-
- if maxID != "" {
- fq = fq.Where("block.id < ?", maxID)
- }
-
- if sinceID != "" {
- fq = fq.Where("block.id > ?", sinceID)
- }
-
- if limit > 0 {
- fq = fq.Limit(limit)
- }
-
- err := fq.Select()
- if err != nil {
- if err == pg.ErrNoRows {
- return nil, "", "", db.ErrNoEntries
- }
- return nil, "", "", err
- }
-
- if len(blocks) == 0 {
- return nil, "", "", db.ErrNoEntries
- }
-
- accounts := []*gtsmodel.Account{}
- for _, b := range blocks {
- accounts = append(accounts, b.TargetAccount)
- }
-
- nextMaxID := blocks[len(blocks)-1].ID
- prevMinID := blocks[0].ID
- return accounts, nextMaxID, prevMinID, nil
-}
diff --git a/internal/db/pg/account_test.go b/internal/db/pg/account_test.go
deleted file mode 100644
index 7ea5ff39a..000000000
--- a/internal/db/pg/account_test.go
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-*/
-
-package pg_test
-
-import (
- "testing"
-
- "github.com/stretchr/testify/suite"
- "github.com/superseriousbusiness/gotosocial/testrig"
-)
-
-type AccountTestSuite struct {
- PGStandardTestSuite
-}
-
-func (suite *AccountTestSuite) SetupSuite() {
- suite.testTokens = testrig.NewTestTokens()
- suite.testClients = testrig.NewTestClients()
- suite.testApplications = testrig.NewTestApplications()
- suite.testUsers = testrig.NewTestUsers()
- suite.testAccounts = testrig.NewTestAccounts()
- suite.testAttachments = testrig.NewTestAttachments()
- suite.testStatuses = testrig.NewTestStatuses()
- suite.testTags = testrig.NewTestTags()
- suite.testMentions = testrig.NewTestMentions()
-}
-
-func (suite *AccountTestSuite) SetupTest() {
- suite.config = testrig.NewTestConfig()
- suite.db = testrig.NewTestDB()
- suite.log = testrig.NewTestLog()
-
- testrig.StandardDBSetup(suite.db, suite.testAccounts)
-}
-
-func (suite *AccountTestSuite) TearDownTest() {
- testrig.StandardDBTeardown(suite.db)
-}
-
-func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
- account, err := suite.db.GetAccountByID(suite.testAccounts["local_account_1"].ID)
- if err != nil {
- suite.FailNow(err.Error())
- }
- suite.NotNil(account)
- suite.NotNil(account.AvatarMediaAttachment)
- suite.NotEmpty(account.AvatarMediaAttachment.URL)
- suite.NotNil(account.HeaderMediaAttachment)
- suite.NotEmpty(account.HeaderMediaAttachment.URL)
-}
-
-func TestAccountTestSuite(t *testing.T) {
- suite.Run(t, new(AccountTestSuite))
-}
diff --git a/internal/db/pg/admin.go b/internal/db/pg/admin.go
deleted file mode 100644
index 854f56ef0..000000000
--- a/internal/db/pg/admin.go
+++ /dev/null
@@ -1,235 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-*/
-
-package pg
-
-import (
- "context"
- "crypto/rand"
- "crypto/rsa"
- "fmt"
- "net"
- "net/mail"
- "strings"
- "time"
-
- "github.com/go-pg/pg/v10"
- "github.com/sirupsen/logrus"
- "github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/id"
- "github.com/superseriousbusiness/gotosocial/internal/util"
- "golang.org/x/crypto/bcrypt"
-)
-
-type adminDB struct {
- config *config.Config
- conn *pg.DB
- log *logrus.Logger
- cancel context.CancelFunc
-}
-
-func (a *adminDB) IsUsernameAvailable(username string) db.Error {
- // if no error we fail because it means we found something
- // if error but it's not pg.ErrNoRows then we fail
- // if err is pg.ErrNoRows we're good, we found nothing so continue
- if err := a.conn.Model(>smodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
- return fmt.Errorf("username %s already in use", username)
- } else if err != pg.ErrNoRows {
- return fmt.Errorf("db error: %s", err)
- }
- return nil
-}
-
-func (a *adminDB) IsEmailAvailable(email string) db.Error {
- // parse the domain from the email
- m, err := mail.ParseAddress(email)
- if err != nil {
- return fmt.Errorf("error parsing email address %s: %s", email, err)
- }
- domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
-
- // check if the email domain is blocked
- if err := a.conn.Model(>smodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
- // fail because we found something
- return fmt.Errorf("email domain %s is blocked", domain)
- } else if err != pg.ErrNoRows {
- // fail because we got an unexpected error
- return fmt.Errorf("db error: %s", err)
- }
-
- // check if this email is associated with a user already
- if err := a.conn.Model(>smodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
- // fail because we found something
- return fmt.Errorf("email %s already in use", email)
- } else if err != pg.ErrNoRows {
- // fail because we got an unexpected error
- return fmt.Errorf("db error: %s", err)
- }
- return nil
-}
-
-func (a *adminDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
- key, err := rsa.GenerateKey(rand.Reader, 2048)
- if err != nil {
- a.log.Errorf("error creating new rsa key: %s", err)
- return nil, err
- }
-
- // if something went wrong while creating a user, we might already have an account, so check here first...
- acct := >smodel.Account{}
- err = a.conn.Model(acct).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select()
- if err != nil {
- // there's been an actual error
- if err != pg.ErrNoRows {
- return nil, fmt.Errorf("db error checking existence of account: %s", err)
- }
-
- // we just don't have an account yet create one
- newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
- newAccountID, err := id.NewRandomULID()
- if err != nil {
- return nil, err
- }
-
- acct = >smodel.Account{
- ID: newAccountID,
- Username: username,
- DisplayName: username,
- Reason: reason,
- URL: newAccountURIs.UserURL,
- PrivateKey: key,
- PublicKey: &key.PublicKey,
- PublicKeyURI: newAccountURIs.PublicKeyURI,
- ActorType: gtsmodel.ActivityStreamsPerson,
- URI: newAccountURIs.UserURI,
- InboxURI: newAccountURIs.InboxURI,
- OutboxURI: newAccountURIs.OutboxURI,
- FollowersURI: newAccountURIs.FollowersURI,
- FollowingURI: newAccountURIs.FollowingURI,
- FeaturedCollectionURI: newAccountURIs.CollectionURI,
- }
- if _, err = a.conn.Model(acct).Insert(); err != nil {
- return nil, err
- }
- }
-
- pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
- if err != nil {
- return nil, fmt.Errorf("error hashing password: %s", err)
- }
-
- newUserID, err := id.NewRandomULID()
- if err != nil {
- return nil, err
- }
-
- u := >smodel.User{
- ID: newUserID,
- AccountID: acct.ID,
- EncryptedPassword: string(pw),
- SignUpIP: signUpIP.To4(),
- Locale: locale,
- UnconfirmedEmail: email,
- CreatedByApplicationID: appID,
- Approved: !requireApproval, // if we don't require moderator approval, just pre-approve the user
- }
-
- if emailVerified {
- u.ConfirmedAt = time.Now()
- u.Email = email
- }
-
- if admin {
- u.Admin = true
- u.Moderator = true
- }
-
- if _, err = a.conn.Model(u).Insert(); err != nil {
- return nil, err
- }
-
- return u, nil
-}
-
-func (a *adminDB) CreateInstanceAccount() db.Error {
- username := a.config.Host
- key, err := rsa.GenerateKey(rand.Reader, 2048)
- if err != nil {
- a.log.Errorf("error creating new rsa key: %s", err)
- return err
- }
-
- aID, err := id.NewRandomULID()
- if err != nil {
- return err
- }
-
- newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
- acct := >smodel.Account{
- ID: aID,
- Username: a.config.Host,
- DisplayName: username,
- URL: newAccountURIs.UserURL,
- PrivateKey: key,
- PublicKey: &key.PublicKey,
- PublicKeyURI: newAccountURIs.PublicKeyURI,
- ActorType: gtsmodel.ActivityStreamsPerson,
- URI: newAccountURIs.UserURI,
- InboxURI: newAccountURIs.InboxURI,
- OutboxURI: newAccountURIs.OutboxURI,
- FollowersURI: newAccountURIs.FollowersURI,
- FollowingURI: newAccountURIs.FollowingURI,
- FeaturedCollectionURI: newAccountURIs.CollectionURI,
- }
- inserted, err := a.conn.Model(acct).Where("username = ?", username).SelectOrInsert()
- if err != nil {
- return err
- }
- if inserted {
- a.log.Infof("created instance account %s with id %s", username, acct.ID)
- } else {
- a.log.Infof("instance account %s already exists with id %s", username, acct.ID)
- }
- return nil
-}
-
-func (a *adminDB) CreateInstanceInstance() db.Error {
- iID, err := id.NewRandomULID()
- if err != nil {
- return err
- }
-
- i := >smodel.Instance{
- ID: iID,
- Domain: a.config.Host,
- Title: a.config.Host,
- URI: fmt.Sprintf("%s://%s", a.config.Protocol, a.config.Host),
- }
- inserted, err := a.conn.Model(i).Where("domain = ?", a.config.Host).SelectOrInsert()
- if err != nil {
- return err
- }
- if inserted {
- a.log.Infof("created instance instance %s with id %s", a.config.Host, i.ID)
- } else {
- a.log.Infof("instance instance %s already exists with id %s", a.config.Host, i.ID)
- }
- return nil
-}
diff --git a/internal/db/pg/basic.go b/internal/db/pg/basic.go
deleted file mode 100644
index 6e76b4450..000000000
--- a/internal/db/pg/basic.go
+++ /dev/null
@@ -1,205 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-*/
-
-package pg
-
-import (
- "context"
- "errors"
- "fmt"
- "strings"
-
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
- "github.com/sirupsen/logrus"
- "github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
-)
-
-type basicDB struct {
- config *config.Config
- conn *pg.DB
- log *logrus.Logger
- cancel context.CancelFunc
-}
-
-func (b *basicDB) Put(i interface{}) db.Error {
- _, err := b.conn.Model(i).Insert(i)
- if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
- return db.ErrAlreadyExists
- }
- return err
-}
-
-func (b *basicDB) GetByID(id string, i interface{}) db.Error {
- if err := b.conn.Model(i).Where("id = ?", id).Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries
- }
- return err
-
- }
- return nil
-}
-
-func (b *basicDB) GetWhere(where []db.Where, i interface{}) db.Error {
- if len(where) == 0 {
- return errors.New("no queries provided")
- }
-
- q := b.conn.Model(i)
- for _, w := range where {
-
- if w.Value == nil {
- q = q.Where("? IS NULL", pg.Ident(w.Key))
- } else {
- if w.CaseInsensitive {
- q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
- } else {
- q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
- }
- }
- }
-
- if err := q.Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries
- }
- return err
- }
- return nil
-}
-
-func (b *basicDB) GetAll(i interface{}) db.Error {
- if err := b.conn.Model(i).Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries
- }
- return err
- }
- return nil
-}
-
-func (b *basicDB) DeleteByID(id string, i interface{}) db.Error {
- if _, err := b.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
- // if there are no rows *anyway* then that's fine
- // just return err if there's an actual error
- if err != pg.ErrNoRows {
- return err
- }
- }
- return nil
-}
-
-func (b *basicDB) DeleteWhere(where []db.Where, i interface{}) db.Error {
- if len(where) == 0 {
- return errors.New("no queries provided")
- }
-
- q := b.conn.Model(i)
- for _, w := range where {
- q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
- }
-
- if _, err := q.Delete(); err != nil {
- // if there are no rows *anyway* then that's fine
- // just return err if there's an actual error
- if err != pg.ErrNoRows {
- return err
- }
- }
- return nil
-}
-
-func (b *basicDB) Upsert(i interface{}, conflictColumn string) db.Error {
- if _, err := b.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries
- }
- return err
- }
- return nil
-}
-
-func (b *basicDB) UpdateByID(id string, i interface{}) db.Error {
- if _, err := b.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries
- }
- return err
- }
- return nil
-}
-
-func (b *basicDB) UpdateOneByID(id string, key string, value interface{}, i interface{}) db.Error {
- _, err := b.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
- return err
-}
-
-func (b *basicDB) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) db.Error {
- q := b.conn.Model(i)
-
- for _, w := range where {
- if w.Value == nil {
- q = q.Where("? IS NULL", pg.Ident(w.Key))
- } else {
- if w.CaseInsensitive {
- q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
- } else {
- q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
- }
- }
- }
-
- q = q.Set("? = ?", pg.Safe(key), value)
-
- _, err := q.Update()
-
- return err
-}
-
-func (b *basicDB) CreateTable(i interface{}) db.Error {
- return b.conn.Model(i).CreateTable(&orm.CreateTableOptions{
- IfNotExists: true,
- })
-}
-
-func (b *basicDB) DropTable(i interface{}) db.Error {
- return b.conn.Model(i).DropTable(&orm.DropTableOptions{
- IfExists: true,
- })
-}
-
-func (b *basicDB) RegisterTable(i interface{}) db.Error {
- orm.RegisterTable(i)
- return nil
-}
-
-func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
- return b.conn.Ping(ctx)
-}
-
-func (b *basicDB) Stop(ctx context.Context) db.Error {
- b.log.Info("closing db connection")
- if err := b.conn.Close(); err != nil {
- // only cancel if there's a problem closing the db
- b.cancel()
- return err
- }
- return nil
-}
diff --git a/internal/db/pg/domain.go b/internal/db/pg/domain.go
deleted file mode 100644
index 4e9b2ab48..000000000
--- a/internal/db/pg/domain.go
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-*/
-
-package pg
-
-import (
- "context"
- "net/url"
-
- "github.com/go-pg/pg/v10"
- "github.com/sirupsen/logrus"
- "github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/util"
-)
-
-type domainDB struct {
- config *config.Config
- conn *pg.DB
- log *logrus.Logger
- cancel context.CancelFunc
-}
-
-func (d *domainDB) IsDomainBlocked(domain string) (bool, db.Error) {
- if domain == "" {
- return false, nil
- }
-
- blocked, err := d.conn.
- Model(>smodel.DomainBlock{}).
- Where("LOWER(domain) = LOWER(?)", domain).
- Exists()
-
- err = processErrorResponse(err)
-
- return blocked, err
-}
-
-func (d *domainDB) AreDomainsBlocked(domains []string) (bool, db.Error) {
- // filter out any doubles
- uniqueDomains := util.UniqueStrings(domains)
-
- for _, domain := range uniqueDomains {
- if blocked, err := d.IsDomainBlocked(domain); err != nil {
- return false, err
- } else if blocked {
- return blocked, nil
- }
- }
-
- // no blocks found
- return false, nil
-}
-
-func (d *domainDB) IsURIBlocked(uri *url.URL) (bool, db.Error) {
- domain := uri.Hostname()
- return d.IsDomainBlocked(domain)
-}
-
-func (d *domainDB) AreURIsBlocked(uris []*url.URL) (bool, db.Error) {
- domains := []string{}
- for _, uri := range uris {
- domains = append(domains, uri.Hostname())
- }
-
- return d.AreDomainsBlocked(domains)
-}
diff --git a/internal/db/pg/instance.go b/internal/db/pg/instance.go
deleted file mode 100644
index 968832ca5..000000000
--- a/internal/db/pg/instance.go
+++ /dev/null
@@ -1,112 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-*/
-
-package pg
-
-import (
- "context"
-
- "github.com/go-pg/pg/v10"
- "github.com/sirupsen/logrus"
- "github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
-)
-
-type instanceDB struct {
- config *config.Config
- conn *pg.DB
- log *logrus.Logger
- cancel context.CancelFunc
-}
-
-func (i *instanceDB) CountInstanceUsers(domain string) (int, db.Error) {
- q := i.conn.Model(&[]*gtsmodel.Account{})
-
- if domain == i.config.Host {
- // if the domain is *this* domain, just count where the domain field is null
- q = q.Where("? IS NULL", pg.Ident("domain"))
- } else {
- q = q.Where("domain = ?", domain)
- }
-
- // don't count the instance account or suspended users
- q = q.Where("username != ?", domain).Where("? IS NULL", pg.Ident("suspended_at"))
-
- return q.Count()
-}
-
-func (i *instanceDB) CountInstanceStatuses(domain string) (int, db.Error) {
- q := i.conn.Model(&[]*gtsmodel.Status{})
-
- if domain == i.config.Host {
- // if the domain is *this* domain, just count where local is true
- q = q.Where("local = ?", true)
- } else {
- // join on the domain of the account
- q = q.Join("JOIN accounts AS account ON account.id = status.account_id").
- Where("account.domain = ?", domain)
- }
-
- return q.Count()
-}
-
-func (i *instanceDB) CountInstanceDomains(domain string) (int, db.Error) {
- q := i.conn.Model(&[]*gtsmodel.Instance{})
-
- if domain == i.config.Host {
- // if the domain is *this* domain, just count other instances it knows about
- // exclude domains that are blocked
- q = q.Where("domain != ?", domain).Where("? IS NULL", pg.Ident("suspended_at"))
- } else {
- // TODO: implement federated domain counting properly for remote domains
- return 0, nil
- }
-
- return q.Count()
-}
-
-func (i *instanceDB) GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
- i.log.Debug("GetAccountsForInstance")
-
- accounts := []*gtsmodel.Account{}
-
- q := i.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC")
-
- if maxID != "" {
- q = q.Where("id < ?", maxID)
- }
-
- if limit > 0 {
- q = q.Limit(limit)
- }
-
- err := q.Select()
- if err != nil {
- if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries
- }
- return nil, err
- }
-
- if len(accounts) == 0 {
- return nil, db.ErrNoEntries
- }
-
- return accounts, nil
-}
diff --git a/internal/db/pg/media.go b/internal/db/pg/media.go
deleted file mode 100644
index 618030af3..000000000
--- a/internal/db/pg/media.go
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-*/
-
-package pg
-
-import (
- "context"
-
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
- "github.com/sirupsen/logrus"
- "github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
-)
-
-type mediaDB struct {
- config *config.Config
- conn *pg.DB
- log *logrus.Logger
- cancel context.CancelFunc
-}
-
-func (m *mediaDB) newMediaQ(i interface{}) *orm.Query {
- return m.conn.Model(i).
- Relation("Account")
-}
-
-func (m *mediaDB) GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, db.Error) {
- attachment := >smodel.MediaAttachment{}
-
- q := m.newMediaQ(attachment).
- Where("media_attachment.id = ?", id)
-
- err := processErrorResponse(q.Select())
-
- return attachment, err
-}
diff --git a/internal/db/pg/mention.go b/internal/db/pg/mention.go
deleted file mode 100644
index b31f07b67..000000000
--- a/internal/db/pg/mention.go
+++ /dev/null
@@ -1,108 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-*/
-
-package pg
-
-import (
- "context"
-
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
- "github.com/sirupsen/logrus"
- "github.com/superseriousbusiness/gotosocial/internal/cache"
- "github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
-)
-
-type mentionDB struct {
- config *config.Config
- conn *pg.DB
- log *logrus.Logger
- cancel context.CancelFunc
- cache cache.Cache
-}
-
-func (m *mentionDB) cacheMention(id string, mention *gtsmodel.Mention) {
- if m.cache == nil {
- m.cache = cache.New()
- }
-
- if err := m.cache.Store(id, mention); err != nil {
- m.log.Panicf("mentionDB: error storing in cache: %s", err)
- }
-}
-
-func (m *mentionDB) mentionCached(id string) (*gtsmodel.Mention, bool) {
- if m.cache == nil {
- m.cache = cache.New()
- return nil, false
- }
-
- mI, err := m.cache.Fetch(id)
- if err != nil || mI == nil {
- return nil, false
- }
-
- mention, ok := mI.(*gtsmodel.Mention)
- if !ok {
- m.log.Panicf("mentionDB: cached interface with key %s was not a mention", id)
- }
-
- return mention, true
-}
-
-func (m *mentionDB) newMentionQ(i interface{}) *orm.Query {
- return m.conn.Model(i).
- Relation("Status").
- Relation("OriginAccount").
- Relation("TargetAccount")
-}
-
-func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) {
- if mention, cached := m.mentionCached(id); cached {
- return mention, nil
- }
-
- mention := >smodel.Mention{}
-
- q := m.newMentionQ(mention).
- Where("mention.id = ?", id)
-
- err := processErrorResponse(q.Select())
-
- if err == nil && mention != nil {
- m.cacheMention(id, mention)
- }
-
- return mention, err
-}
-
-func (m *mentionDB) GetMentions(ids []string) ([]*gtsmodel.Mention, db.Error) {
- mentions := []*gtsmodel.Mention{}
-
- for _, i := range ids {
- mention, err := m.GetMention(i)
- if err != nil {
- return nil, processErrorResponse(err)
- }
- mentions = append(mentions, mention)
- }
-
- return mentions, nil
-}
diff --git a/internal/db/pg/notification.go b/internal/db/pg/notification.go
deleted file mode 100644
index 281a76d85..000000000
--- a/internal/db/pg/notification.go
+++ /dev/null
@@ -1,135 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-*/
-
-package pg
-
-import (
- "context"
-
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
- "github.com/sirupsen/logrus"
- "github.com/superseriousbusiness/gotosocial/internal/cache"
- "github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
-)
-
-type notificationDB struct {
- config *config.Config
- conn *pg.DB
- log *logrus.Logger
- cancel context.CancelFunc
- cache cache.Cache
-}
-
-func (n *notificationDB) cacheNotification(id string, notification *gtsmodel.Notification) {
- if n.cache == nil {
- n.cache = cache.New()
- }
-
- if err := n.cache.Store(id, notification); err != nil {
- n.log.Panicf("notificationDB: error storing in cache: %s", err)
- }
-}
-
-func (n *notificationDB) notificationCached(id string) (*gtsmodel.Notification, bool) {
- if n.cache == nil {
- n.cache = cache.New()
- return nil, false
- }
-
- nI, err := n.cache.Fetch(id)
- if err != nil || nI == nil {
- return nil, false
- }
-
- notification, ok := nI.(*gtsmodel.Notification)
- if !ok {
- n.log.Panicf("notificationDB: cached interface with key %s was not a notification", id)
- }
-
- return notification, true
-}
-
-func (n *notificationDB) newNotificationQ(i interface{}) *orm.Query {
- return n.conn.Model(i).
- Relation("OriginAccount").
- Relation("TargetAccount").
- Relation("Status")
-}
-
-func (n *notificationDB) GetNotification(id string) (*gtsmodel.Notification, db.Error) {
- if notification, cached := n.notificationCached(id); cached {
- return notification, nil
- }
-
- notification := >smodel.Notification{}
-
- q := n.newNotificationQ(notification).
- Where("notification.id = ?", id)
-
- err := processErrorResponse(q.Select())
-
- if err == nil && notification != nil {
- n.cacheNotification(id, notification)
- }
-
- return notification, err
-}
-
-func (n *notificationDB) GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
- // begin by selecting just the IDs
- notifIDs := []*gtsmodel.Notification{}
- q := n.conn.
- Model(¬ifIDs).
- Column("id").
- Where("target_account_id = ?", accountID).
- Order("id DESC")
-
- if maxID != "" {
- q = q.Where("id < ?", maxID)
- }
-
- if sinceID != "" {
- q = q.Where("id > ?", sinceID)
- }
-
- if limit != 0 {
- q = q.Limit(limit)
- }
-
- err := processErrorResponse(q.Select())
- if err != nil {
- return nil, err
- }
-
- // now we have the IDs, select the notifs one by one
- // reason for this is that for each notif, we can instead get it from our cache if it's cached
- notifications := []*gtsmodel.Notification{}
- for _, notifID := range notifIDs {
- notif, err := n.GetNotification(notifID.ID)
- errP := processErrorResponse(err)
- if errP != nil {
- return nil, errP
- }
- notifications = append(notifications, notif)
- }
-
- return notifications, nil
-}
diff --git a/internal/db/pg/pg.go b/internal/db/pg/pg.go
deleted file mode 100644
index 0437baf02..000000000
--- a/internal/db/pg/pg.go
+++ /dev/null
@@ -1,420 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-*/
-
-package pg
-
-import (
- "context"
- "crypto/tls"
- "crypto/x509"
- "encoding/pem"
- "errors"
- "fmt"
- "os"
- "strings"
- "time"
-
- "github.com/go-pg/pg/extra/pgdebug"
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
- "github.com/sirupsen/logrus"
- "github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/id"
-)
-
-var registerTables []interface{} = []interface{}{
- >smodel.StatusToEmoji{},
- >smodel.StatusToTag{},
-}
-
-// postgresService satisfies the DB interface
-type postgresService struct {
- db.Account
- db.Admin
- db.Basic
- db.Domain
- db.Instance
- db.Media
- db.Mention
- db.Notification
- db.Relationship
- db.Status
- db.Timeline
- config *config.Config
- conn *pg.DB
- log *logrus.Logger
- cancel context.CancelFunc
-}
-
-// NewPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
-// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
-func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
- for _, t := range registerTables {
- // https://pg.uptrace.dev/orm/many-to-many-relation/
- orm.RegisterTable(t)
- }
-
- opts, err := derivePGOptions(c)
- if err != nil {
- return nil, fmt.Errorf("could not create postgres service: %s", err)
- }
- log.Debugf("using pg options: %+v", opts)
-
- // create a connection
- pgCtx, cancel := context.WithCancel(ctx)
- conn := pg.Connect(opts).WithContext(pgCtx)
-
- // this will break the logfmt format we normally log in,
- // since we can't choose where pg outputs to and it defaults to
- // stdout. So use this option with care!
- if log.GetLevel() >= logrus.TraceLevel {
- conn.AddQueryHook(pgdebug.DebugHook{
- // Print all queries.
- Verbose: true,
- })
- }
-
- // actually *begin* the connection so that we can tell if the db is there and listening
- if err := conn.Ping(ctx); err != nil {
- cancel()
- return nil, fmt.Errorf("db connection error: %s", err)
- }
-
- // print out discovered postgres version
- var version string
- if _, err = conn.QueryOneContext(ctx, pg.Scan(&version), "SELECT version()"); err != nil {
- cancel()
- return nil, fmt.Errorf("db connection error: %s", err)
- }
- log.Infof("connected to postgres version: %s", version)
-
- ps := &postgresService{
- Account: &accountDB{
- config: c,
- conn: conn,
- log: log,
- cancel: cancel,
- },
- Admin: &adminDB{
- config: c,
- conn: conn,
- log: log,
- cancel: cancel,
- },
- Basic: &basicDB{
- config: c,
- conn: conn,
- log: log,
- cancel: cancel,
- },
- Domain: &domainDB{
- config: c,
- conn: conn,
- log: log,
- cancel: cancel,
- },
- Instance: &instanceDB{
- config: c,
- conn: conn,
- log: log,
- cancel: cancel,
- },
- Media: &mediaDB{
- config: c,
- conn: conn,
- log: log,
- cancel: cancel,
- },
- Mention: &mentionDB{
- config: c,
- conn: conn,
- log: log,
- cancel: cancel,
- },
- Notification: ¬ificationDB{
- config: c,
- conn: conn,
- log: log,
- cancel: cancel,
- },
- Relationship: &relationshipDB{
- config: c,
- conn: conn,
- log: log,
- cancel: cancel,
- },
- Status: &statusDB{
- config: c,
- conn: conn,
- log: log,
- cancel: cancel,
- },
- Timeline: &timelineDB{
- config: c,
- conn: conn,
- log: log,
- cancel: cancel,
- },
- config: c,
- conn: conn,
- log: log,
- cancel: cancel,
- }
-
- // we can confidently return this useable postgres service now
- return ps, nil
-}
-
-/*
- HANDY STUFF
-*/
-
-// derivePGOptions takes an application config and returns either a ready-to-use *pg.Options
-// with sensible defaults, or an error if it's not satisfied by the provided config.
-func derivePGOptions(c *config.Config) (*pg.Options, error) {
- if strings.ToUpper(c.DBConfig.Type) != db.DBTypePostgres {
- return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, c.DBConfig.Type)
- }
-
- // validate port
- if c.DBConfig.Port == 0 {
- return nil, errors.New("no port set")
- }
-
- // validate address
- if c.DBConfig.Address == "" {
- return nil, errors.New("no address set")
- }
-
- // validate username
- if c.DBConfig.User == "" {
- return nil, errors.New("no user set")
- }
-
- // validate that there's a password
- if c.DBConfig.Password == "" {
- return nil, errors.New("no password set")
- }
-
- // validate database
- if c.DBConfig.Database == "" {
- return nil, errors.New("no database set")
- }
-
- var tlsConfig *tls.Config
- switch c.DBConfig.TLSMode {
- case config.DBTLSModeDisable, config.DBTLSModeUnset:
- break // nothing to do
- case config.DBTLSModeEnable:
- tlsConfig = &tls.Config{
- InsecureSkipVerify: true,
- }
- case config.DBTLSModeRequire:
- tlsConfig = &tls.Config{
- InsecureSkipVerify: false,
- ServerName: c.DBConfig.Address,
- }
- }
-
- if tlsConfig != nil && c.DBConfig.TLSCACert != "" {
- // load the system cert pool first -- we'll append the given CA cert to this
- certPool, err := x509.SystemCertPool()
- if err != nil {
- return nil, fmt.Errorf("error fetching system CA cert pool: %s", err)
- }
-
- // open the file itself and make sure there's something in it
- caCertBytes, err := os.ReadFile(c.DBConfig.TLSCACert)
- if err != nil {
- return nil, fmt.Errorf("error opening CA certificate at %s: %s", c.DBConfig.TLSCACert, err)
- }
- if len(caCertBytes) == 0 {
- return nil, fmt.Errorf("ca cert at %s was empty", c.DBConfig.TLSCACert)
- }
-
- // make sure we have a PEM block
- caPem, _ := pem.Decode(caCertBytes)
- if caPem == nil {
- return nil, fmt.Errorf("could not parse cert at %s into PEM", c.DBConfig.TLSCACert)
- }
-
- // parse the PEM block into the certificate
- caCert, err := x509.ParseCertificate(caPem.Bytes)
- if err != nil {
- return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", c.DBConfig.TLSCACert, err)
- }
-
- // we're happy, add it to the existing pool and then use this pool in our tls config
- certPool.AddCert(caCert)
- tlsConfig.RootCAs = certPool
- }
-
- // We can rely on the pg library we're using to set
- // sensible defaults for everything we don't set here.
- options := &pg.Options{
- Addr: fmt.Sprintf("%s:%d", c.DBConfig.Address, c.DBConfig.Port),
- User: c.DBConfig.User,
- Password: c.DBConfig.Password,
- Database: c.DBConfig.Database,
- ApplicationName: c.ApplicationName,
- TLSConfig: tlsConfig,
- }
-
- return options, nil
-}
-
-/*
- CONVERSION FUNCTIONS
-*/
-
-// TODO: move these to the type converter, it's bananas that they're here and not there
-
-func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) {
- ogAccount := >smodel.Account{}
- if err := ps.conn.Model(ogAccount).Where("id = ?", originAccountID).Select(); err != nil {
- return nil, err
- }
-
- menchies := []*gtsmodel.Mention{}
- for _, a := range targetAccounts {
- // A mentioned account looks like "@test@example.org" or just "@test" for a local account
- // -- we can guarantee this from the regex that targetAccounts should have been derived from.
- // But we still need to do a bit of fiddling to get what we need here -- the username and domain (if given).
-
- // 1. trim off the first @
- t := strings.TrimPrefix(a, "@")
-
- // 2. split the username and domain
- s := strings.Split(t, "@")
-
- // 3. if it's length 1 it's a local account, length 2 means remote, anything else means something is wrong
- var local bool
- switch len(s) {
- case 1:
- local = true
- case 2:
- local = false
- default:
- return nil, fmt.Errorf("mentioned account format '%s' was not valid", a)
- }
-
- var username, domain string
- username = s[0]
- if !local {
- domain = s[1]
- }
-
- // 4. check we now have a proper username and domain
- if username == "" || (!local && domain == "") {
- return nil, fmt.Errorf("username or domain for '%s' was nil", a)
- }
-
- // okay we're good now, we can start pulling accounts out of the database
- mentionedAccount := >smodel.Account{}
- var err error
-
- // match username + account, case insensitive
- if local {
- // local user -- should have a null domain
- err = ps.conn.Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", pg.Ident("username"), username).Where("? IS NULL", pg.Ident("domain")).Select()
- } else {
- // remote user -- should have domain defined
- err = ps.conn.Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", pg.Ident("username"), username).Where("LOWER(?) = LOWER(?)", pg.Ident("domain"), domain).Select()
- }
-
- if err != nil {
- if err == pg.ErrNoRows {
- // no result found for this username/domain so just don't include it as a mencho and carry on about our business
- ps.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain)
- continue
- }
- // a serious error has happened so bail
- return nil, fmt.Errorf("error getting account with username '%s' and domain '%s': %s", username, domain, err)
- }
-
- // id, createdAt and updatedAt will be populated by the db, so we have everything we need!
- menchies = append(menchies, >smodel.Mention{
- StatusID: statusID,
- OriginAccountID: ogAccount.ID,
- OriginAccountURI: ogAccount.URI,
- TargetAccountID: mentionedAccount.ID,
- NameString: a,
- TargetAccountURI: mentionedAccount.URI,
- TargetAccountURL: mentionedAccount.URL,
- OriginAccount: mentionedAccount,
- })
- }
- return menchies, nil
-}
-
-func (ps *postgresService) TagStringsToTags(tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) {
- newTags := []*gtsmodel.Tag{}
- for _, t := range tags {
- tag := >smodel.Tag{}
- // we can use selectorinsert here to create the new tag if it doesn't exist already
- // inserted will be true if this is a new tag we just created
- if err := ps.conn.Model(tag).Where("LOWER(?) = LOWER(?)", pg.Ident("name"), t).Select(); err != nil {
- if err == pg.ErrNoRows {
- // tag doesn't exist yet so populate it
- newID, err := id.NewRandomULID()
- if err != nil {
- return nil, err
- }
- tag.ID = newID
- tag.URL = fmt.Sprintf("%s://%s/tags/%s", ps.config.Protocol, ps.config.Host, t)
- tag.Name = t
- tag.FirstSeenFromAccountID = originAccountID
- tag.CreatedAt = time.Now()
- tag.UpdatedAt = time.Now()
- tag.Useable = true
- tag.Listable = true
- } else {
- return nil, fmt.Errorf("error getting tag with name %s: %s", t, err)
- }
- }
-
- // bail already if the tag isn't useable
- if !tag.Useable {
- continue
- }
- tag.LastStatusAt = time.Now()
- newTags = append(newTags, tag)
- }
- return newTags, nil
-}
-
-func (ps *postgresService) EmojiStringsToEmojis(emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) {
- newEmojis := []*gtsmodel.Emoji{}
- for _, e := range emojis {
- emoji := >smodel.Emoji{}
- err := ps.conn.Model(emoji).Where("shortcode = ?", e).Where("visible_in_picker = true").Where("disabled = false").Select()
- if err != nil {
- if err == pg.ErrNoRows {
- // no result found for this username/domain so just don't include it as an emoji and carry on about our business
- ps.log.Debugf("no emoji found with shortcode %s, skipping it", e)
- continue
- }
- // a serious error has happened so bail
- return nil, fmt.Errorf("error getting emoji with shortcode %s: %s", e, err)
- }
- newEmojis = append(newEmojis, emoji)
- }
- return newEmojis, nil
-}
diff --git a/internal/db/pg/pg_test.go b/internal/db/pg/pg_test.go
deleted file mode 100644
index c1e10abdf..000000000
--- a/internal/db/pg/pg_test.go
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-*/
-
-package pg_test
-
-import (
- "github.com/sirupsen/logrus"
- "github.com/stretchr/testify/suite"
- "github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/oauth"
-)
-
-type PGStandardTestSuite struct {
- // standard suite interfaces
- suite.Suite
- config *config.Config
- db db.DB
- log *logrus.Logger
-
- // standard suite models
- testTokens map[string]*oauth.Token
- testClients map[string]*oauth.Client
- testApplications map[string]*gtsmodel.Application
- testUsers map[string]*gtsmodel.User
- testAccounts map[string]*gtsmodel.Account
- testAttachments map[string]*gtsmodel.MediaAttachment
- testStatuses map[string]*gtsmodel.Status
- testTags map[string]*gtsmodel.Tag
- testMentions map[string]*gtsmodel.Mention
-}
diff --git a/internal/db/pg/relationship.go b/internal/db/pg/relationship.go
deleted file mode 100644
index 76bd50c76..000000000
--- a/internal/db/pg/relationship.go
+++ /dev/null
@@ -1,276 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-*/
-
-package pg
-
-import (
- "context"
- "fmt"
-
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
- "github.com/sirupsen/logrus"
- "github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
-)
-
-type relationshipDB struct {
- config *config.Config
- conn *pg.DB
- log *logrus.Logger
- cancel context.CancelFunc
-}
-
-func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *orm.Query {
- return r.conn.Model(block).
- Relation("Account").
- Relation("TargetAccount")
-}
-
-func (r *relationshipDB) newFollowQ(follow interface{}) *orm.Query {
- return r.conn.Model(follow).
- Relation("Account").
- Relation("TargetAccount")
-}
-
-func (r *relationshipDB) IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
- q := r.conn.
- Model(>smodel.Block{}).
- Where("account_id = ?", account1).
- Where("target_account_id = ?", account2)
-
- if eitherDirection {
- q = q.
- WhereOr("target_account_id = ?", account1).
- Where("account_id = ?", account2)
- }
-
- return q.Exists()
-}
-
-func (r *relationshipDB) GetBlock(account1 string, account2 string) (*gtsmodel.Block, db.Error) {
- block := >smodel.Block{}
-
- q := r.newBlockQ(block).
- Where("block.account_id = ?", account1).
- Where("block.target_account_id = ?", account2)
-
- err := processErrorResponse(q.Select())
-
- return block, err
-}
-
-func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
- rel := >smodel.Relationship{
- ID: targetAccount,
- }
-
- // check if the requesting account follows the target account
- follow := >smodel.Follow{}
- if err := r.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil {
- if err != pg.ErrNoRows {
- // a proper error
- return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
- }
- // no follow exists so these are all false
- rel.Following = false
- rel.ShowingReblogs = false
- rel.Notifying = false
- } else {
- // follow exists so we can fill these fields out...
- rel.Following = true
- rel.ShowingReblogs = follow.ShowReblogs
- rel.Notifying = follow.Notify
- }
-
- // check if the target account follows the requesting account
- followedBy, err := r.conn.Model(>smodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
- if err != nil {
- return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
- }
- rel.FollowedBy = followedBy
-
- // check if the requesting account blocks the target account
- blocking, err := r.conn.Model(>smodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
- if err != nil {
- return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
- }
- rel.Blocking = blocking
-
- // check if the target account blocks the requesting account
- blockedBy, err := r.conn.Model(>smodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
- if err != nil {
- return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
- }
- rel.BlockedBy = blockedBy
-
- // check if there's a pending following request from requesting account to target account
- requested, err := r.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
- if err != nil {
- return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
- }
- rel.Requested = requested
-
- return rel, nil
-}
-
-func (r *relationshipDB) IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
- if sourceAccount == nil || targetAccount == nil {
- return false, nil
- }
-
- q := r.conn.
- Model(>smodel.Follow{}).
- Where("account_id = ?", sourceAccount.ID).
- Where("target_account_id = ?", targetAccount.ID)
-
- return q.Exists()
-}
-
-func (r *relationshipDB) IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
- if sourceAccount == nil || targetAccount == nil {
- return false, nil
- }
-
- q := r.conn.
- Model(>smodel.FollowRequest{}).
- Where("account_id = ?", sourceAccount.ID).
- Where("target_account_id = ?", targetAccount.ID)
-
- return q.Exists()
-}
-
-func (r *relationshipDB) IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
- if account1 == nil || account2 == nil {
- return false, nil
- }
-
- // make sure account 1 follows account 2
- f1, err := r.IsFollowing(account1, account2)
- if err != nil {
- return false, processErrorResponse(err)
- }
-
- // make sure account 2 follows account 1
- f2, err := r.IsFollowing(account2, account1)
- if err != nil {
- return false, processErrorResponse(err)
- }
-
- return f1 && f2, nil
-}
-
-func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
- // make sure the original follow request exists
- fr := >smodel.FollowRequest{}
- if err := r.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil {
- if err == pg.ErrMultiRows {
- return nil, db.ErrNoEntries
- }
- return nil, err
- }
-
- // create a new follow to 'replace' the request with
- follow := >smodel.Follow{
- ID: fr.ID,
- AccountID: originAccountID,
- TargetAccountID: targetAccountID,
- URI: fr.URI,
- }
-
- // if the follow already exists, just update the URI -- we don't need to do anything else
- if _, err := r.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil {
- return nil, err
- }
-
- // now remove the follow request
- if _, err := r.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil {
- return nil, err
- }
-
- return follow, nil
-}
-
-func (r *relationshipDB) GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, db.Error) {
- followRequests := []*gtsmodel.FollowRequest{}
-
- q := r.newFollowQ(&followRequests).
- Where("target_account_id = ?", accountID)
-
- err := processErrorResponse(q.Select())
-
- return followRequests, err
-}
-
-func (r *relationshipDB) GetAccountFollows(accountID string) ([]*gtsmodel.Follow, db.Error) {
- follows := []*gtsmodel.Follow{}
-
- q := r.newFollowQ(&follows).
- Where("account_id = ?", accountID)
-
- err := processErrorResponse(q.Select())
-
- return follows, err
-}
-
-func (r *relationshipDB) CountAccountFollows(accountID string, localOnly bool) (int, db.Error) {
- return r.conn.
- Model(&[]*gtsmodel.Follow{}).
- Where("account_id = ?", accountID).
- Count()
-}
-
-func (r *relationshipDB) GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
-
- follows := []*gtsmodel.Follow{}
-
- q := r.conn.Model(&follows)
-
- if localOnly {
- // for local accounts let's get where domain is null OR where domain is an empty string, just to be safe
- whereGroup := func(q *pg.Query) (*pg.Query, error) {
- q = q.
- WhereOr("? IS NULL", pg.Ident("a.domain")).
- WhereOr("a.domain = ?", "")
- return q, nil
- }
-
- q = q.ColumnExpr("follow.*").
- Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)").
- Where("follow.target_account_id = ?", accountID).
- WhereGroup(whereGroup)
- } else {
- q = q.Where("target_account_id = ?", accountID)
- }
-
- if err := q.Select(); err != nil {
- if err == pg.ErrNoRows {
- return follows, nil
- }
- return nil, err
- }
- return follows, nil
-}
-
-func (r *relationshipDB) CountAccountFollowedBy(accountID string, localOnly bool) (int, db.Error) {
- return r.conn.
- Model(&[]*gtsmodel.Follow{}).
- Where("target_account_id = ?", accountID).
- Count()
-}
diff --git a/internal/db/pg/status.go b/internal/db/pg/status.go
deleted file mode 100644
index 99790428e..000000000
--- a/internal/db/pg/status.go
+++ /dev/null
@@ -1,318 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-*/
-
-package pg
-
-import (
- "container/list"
- "context"
- "errors"
- "time"
-
- "github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
- "github.com/sirupsen/logrus"
- "github.com/superseriousbusiness/gotosocial/internal/cache"
- "github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
-)
-
-type statusDB struct {
- config *config.Config
- conn *pg.DB
- log *logrus.Logger
- cancel context.CancelFunc
- cache cache.Cache
-}
-
-func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) {
- if s.cache == nil {
- s.cache = cache.New()
- }
-
- if err := s.cache.Store(id, status); err != nil {
- s.log.Panicf("statusDB: error storing in cache: %s", err)
- }
-}
-
-func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) {
- if s.cache == nil {
- s.cache = cache.New()
- return nil, false
- }
-
- sI, err := s.cache.Fetch(id)
- if err != nil || sI == nil {
- return nil, false
- }
-
- status, ok := sI.(*gtsmodel.Status)
- if !ok {
- s.log.Panicf("statusDB: cached interface with key %s was not a status", id)
- }
-
- return status, true
-}
-
-func (s *statusDB) newStatusQ(status interface{}) *orm.Query {
- return s.conn.Model(status).
- Relation("Attachments").
- Relation("Tags").
- Relation("Mentions").
- Relation("Emojis").
- Relation("Account").
- Relation("InReplyTo").
- Relation("InReplyToAccount").
- Relation("BoostOf").
- Relation("BoostOfAccount").
- Relation("CreatedWithApplication")
-}
-
-func (s *statusDB) newFaveQ(faves interface{}) *orm.Query {
- return s.conn.Model(faves).
- Relation("Account").
- Relation("TargetAccount").
- Relation("Status")
-}
-
-func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.Error) {
- if status, cached := s.statusCached(id); cached {
- return status, nil
- }
-
- status := >smodel.Status{}
-
- q := s.newStatusQ(status).
- Where("status.id = ?", id)
-
- err := processErrorResponse(q.Select())
-
- if err == nil && status != nil {
- s.cacheStatus(id, status)
- }
-
- return status, err
-}
-
-func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.Error) {
- if status, cached := s.statusCached(uri); cached {
- return status, nil
- }
-
- status := >smodel.Status{}
-
- q := s.newStatusQ(status).
- Where("LOWER(status.uri) = LOWER(?)", uri)
-
- err := processErrorResponse(q.Select())
-
- if err == nil && status != nil {
- s.cacheStatus(uri, status)
- }
-
- return status, err
-}
-
-func (s *statusDB) GetStatusByURL(uri string) (*gtsmodel.Status, db.Error) {
- if status, cached := s.statusCached(uri); cached {
- return status, nil
- }
-
- status := >smodel.Status{}
-
- q := s.newStatusQ(status).
- Where("LOWER(status.url) = LOWER(?)", uri)
-
- err := processErrorResponse(q.Select())
-
- if err == nil && status != nil {
- s.cacheStatus(uri, status)
- }
-
- return status, err
-}
-
-func (s *statusDB) PutStatus(status *gtsmodel.Status) db.Error {
- transaction := func(tx *pg.Tx) error {
- // create links between this status and any emojis it uses
- for _, i := range status.EmojiIDs {
- if _, err := tx.Model(>smodel.StatusToEmoji{
- StatusID: status.ID,
- EmojiID: i,
- }).Insert(); err != nil {
- return err
- }
- }
-
- // create links between this status and any tags it uses
- for _, i := range status.TagIDs {
- if _, err := tx.Model(>smodel.StatusToTag{
- StatusID: status.ID,
- TagID: i,
- }).Insert(); err != nil {
- return err
- }
- }
-
- // change the status ID of the media attachments to the new status
- for _, a := range status.Attachments {
- a.StatusID = status.ID
- a.UpdatedAt = time.Now()
- if _, err := s.conn.Model(a).
- Where("id = ?", a.ID).
- Update(); err != nil {
- return err
- }
- }
-
- _, err := tx.Model(status).Insert()
- return err
- }
-
- return processErrorResponse(s.conn.RunInTransaction(context.Background(), transaction))
-}
-
-func (s *statusDB) GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
- parents := []*gtsmodel.Status{}
- s.statusParent(status, &parents, onlyDirect)
-
- return parents, nil
-}
-
-func (s *statusDB) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
- if status.InReplyToID == "" {
- return
- }
-
- parentStatus, err := s.GetStatusByID(status.InReplyToID)
- if err == nil {
- *foundStatuses = append(*foundStatuses, parentStatus)
- }
-
- if onlyDirect {
- return
- }
-
- s.statusParent(parentStatus, foundStatuses, false)
-}
-
-func (s *statusDB) GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) {
- foundStatuses := &list.List{}
- foundStatuses.PushFront(status)
- s.statusChildren(status, foundStatuses, onlyDirect, minID)
-
- children := []*gtsmodel.Status{}
- for e := foundStatuses.Front(); e != nil; e = e.Next() {
- entry, ok := e.Value.(*gtsmodel.Status)
- if !ok {
- panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
- }
-
- // only append children, not the overall parent status
- if entry.ID != status.ID {
- children = append(children, entry)
- }
- }
-
- return children, nil
-}
-
-func (s *statusDB) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
- immediateChildren := []*gtsmodel.Status{}
-
- q := s.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID)
- if minID != "" {
- q = q.Where("status.id > ?", minID)
- }
-
- if err := q.Select(); err != nil {
- return
- }
-
- for _, child := range immediateChildren {
- insertLoop:
- for e := foundStatuses.Front(); e != nil; e = e.Next() {
- entry, ok := e.Value.(*gtsmodel.Status)
- if !ok {
- panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
- }
-
- if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {
- foundStatuses.InsertAfter(child, e)
- break insertLoop
- }
- }
-
- // only do one loop if we only want direct children
- if onlyDirect {
- return
- }
- s.statusChildren(child, foundStatuses, false, minID)
- }
-}
-
-func (s *statusDB) CountStatusReplies(status *gtsmodel.Status) (int, db.Error) {
- return s.conn.Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count()
-}
-
-func (s *statusDB) CountStatusReblogs(status *gtsmodel.Status) (int, db.Error) {
- return s.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count()
-}
-
-func (s *statusDB) CountStatusFaves(status *gtsmodel.Status) (int, db.Error) {
- return s.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count()
-}
-
-func (s *statusDB) IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
- return s.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
-}
-
-func (s *statusDB) IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
- return s.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
-}
-
-func (s *statusDB) IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
- return s.conn.Model(>smodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
-}
-
-func (s *statusDB) IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
- return s.conn.Model(>smodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
-}
-
-func (s *statusDB) GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
- faves := []*gtsmodel.StatusFave{}
-
- q := s.newFaveQ(&faves).
- Where("status_id = ?", status.ID)
-
- err := processErrorResponse(q.Select())
-
- return faves, err
-}
-
-func (s *statusDB) GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
- reblogs := []*gtsmodel.Status{}
-
- q := s.newStatusQ(&reblogs).
- Where("boost_of_id = ?", status.ID)
-
- err := processErrorResponse(q.Select())
-
- return reblogs, err
-}
diff --git a/internal/db/pg/status_test.go b/internal/db/pg/status_test.go
deleted file mode 100644
index 8a185757c..000000000
--- a/internal/db/pg/status_test.go
+++ /dev/null
@@ -1,134 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-*/
-
-package pg_test
-
-import (
- "fmt"
- "testing"
- "time"
-
- "github.com/stretchr/testify/suite"
- "github.com/superseriousbusiness/gotosocial/testrig"
-)
-
-type StatusTestSuite struct {
- PGStandardTestSuite
-}
-
-func (suite *StatusTestSuite) SetupSuite() {
- suite.testTokens = testrig.NewTestTokens()
- suite.testClients = testrig.NewTestClients()
- suite.testApplications = testrig.NewTestApplications()
- suite.testUsers = testrig.NewTestUsers()
- suite.testAccounts = testrig.NewTestAccounts()
- suite.testAttachments = testrig.NewTestAttachments()
- suite.testStatuses = testrig.NewTestStatuses()
- suite.testTags = testrig.NewTestTags()
- suite.testMentions = testrig.NewTestMentions()
-}
-
-func (suite *StatusTestSuite) SetupTest() {
- suite.config = testrig.NewTestConfig()
- suite.db = testrig.NewTestDB()
- suite.log = testrig.NewTestLog()
-
- testrig.StandardDBSetup(suite.db, suite.testAccounts)
-}
-
-func (suite *StatusTestSuite) TearDownTest() {
- testrig.StandardDBTeardown(suite.db)
-}
-
-func (suite *StatusTestSuite) TestGetStatusByID() {
- status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_1_status_1"].ID)
- if err != nil {
- suite.FailNow(err.Error())
- }
- suite.NotNil(status)
- suite.NotNil(status.Account)
- suite.NotNil(status.CreatedWithApplication)
- suite.Nil(status.BoostOf)
- suite.Nil(status.BoostOfAccount)
- suite.Nil(status.InReplyTo)
- suite.Nil(status.InReplyToAccount)
-}
-
-func (suite *StatusTestSuite) TestGetStatusByURI() {
- status, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
- if err != nil {
- suite.FailNow(err.Error())
- }
- suite.NotNil(status)
- suite.NotNil(status.Account)
- suite.NotNil(status.CreatedWithApplication)
- suite.Nil(status.BoostOf)
- suite.Nil(status.BoostOfAccount)
- suite.Nil(status.InReplyTo)
- suite.Nil(status.InReplyToAccount)
-}
-
-func (suite *StatusTestSuite) TestGetStatusWithExtras() {
- status, err := suite.db.GetStatusByID(suite.testStatuses["admin_account_status_1"].ID)
- if err != nil {
- suite.FailNow(err.Error())
- }
- suite.NotNil(status)
- suite.NotNil(status.Account)
- suite.NotNil(status.CreatedWithApplication)
- suite.NotEmpty(status.Tags)
- suite.NotEmpty(status.Attachments)
- suite.NotEmpty(status.Emojis)
-}
-
-func (suite *StatusTestSuite) TestGetStatusWithMention() {
- status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_2_status_5"].ID)
- if err != nil {
- suite.FailNow(err.Error())
- }
- suite.NotNil(status)
- suite.NotNil(status.Account)
- suite.NotNil(status.CreatedWithApplication)
- suite.NotEmpty(status.Mentions)
- suite.NotEmpty(status.MentionIDs)
- suite.NotNil(status.InReplyTo)
- suite.NotNil(status.InReplyToAccount)
-}
-
-func (suite *StatusTestSuite) TestGetStatusTwice() {
- before1 := time.Now()
- _, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
- suite.NoError(err)
- after1 := time.Now()
- duration1 := after1.Sub(before1)
- fmt.Println(duration1.Nanoseconds())
-
- before2 := time.Now()
- _, err = suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
- suite.NoError(err)
- after2 := time.Now()
- duration2 := after2.Sub(before2)
- fmt.Println(duration2.Nanoseconds())
-
- // second retrieval should be several orders faster since it will be cached now
- suite.Less(duration2, duration1)
-}
-
-func TestStatusTestSuite(t *testing.T) {
- suite.Run(t, new(StatusTestSuite))
-}
diff --git a/internal/db/pg/timeline.go b/internal/db/pg/timeline.go
deleted file mode 100644
index fa8b07aab..000000000
--- a/internal/db/pg/timeline.go
+++ /dev/null
@@ -1,210 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-*/
-
-package pg
-
-import (
- "context"
- "sort"
-
- "github.com/go-pg/pg/v10"
- "github.com/sirupsen/logrus"
- "github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
-)
-
-type timelineDB struct {
- config *config.Config
- conn *pg.DB
- log *logrus.Logger
- cancel context.CancelFunc
-}
-
-func (t *timelineDB) GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
- statuses := []*gtsmodel.Status{}
- q := t.conn.Model(&statuses)
-
- q = q.ColumnExpr("status.*").
- // Find out who accountID follows.
- Join("LEFT JOIN follows AS f ON f.target_account_id = status.account_id").
- // Use a WhereGroup here to specify that we want EITHER statuses posted by accounts that accountID follows,
- // OR statuses posted by accountID itself (since a user should be able to see their own statuses).
- //
- // This is equivalent to something like WHERE ... AND (... OR ...)
- // See: https://pg.uptrace.dev/queries/#select
- WhereGroup(func(q *pg.Query) (*pg.Query, error) {
- q = q.WhereOr("f.account_id = ?", accountID).
- WhereOr("status.account_id = ?", accountID)
- return q, nil
- }).
- // Sort by highest ID (newest) to lowest ID (oldest)
- Order("status.id DESC")
-
- if maxID != "" {
- // return only statuses LOWER (ie., older) than maxID
- q = q.Where("status.id < ?", maxID)
- }
-
- if sinceID != "" {
- // return only statuses HIGHER (ie., newer) than sinceID
- q = q.Where("status.id > ?", sinceID)
- }
-
- if minID != "" {
- // return only statuses HIGHER (ie., newer) than minID
- q = q.Where("status.id > ?", minID)
- }
-
- if local {
- // return only statuses posted by local account havers
- q = q.Where("status.local = ?", local)
- }
-
- if limit > 0 {
- // limit amount of statuses returned
- q = q.Limit(limit)
- }
-
- err := q.Select()
- if err != nil {
- if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries
- }
- return nil, err
- }
-
- if len(statuses) == 0 {
- return nil, db.ErrNoEntries
- }
-
- return statuses, nil
-}
-
-func (t *timelineDB) GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
- statuses := []*gtsmodel.Status{}
-
- q := t.conn.Model(&statuses).
- Where("visibility = ?", gtsmodel.VisibilityPublic).
- Where("? IS NULL", pg.Ident("in_reply_to_id")).
- Where("? IS NULL", pg.Ident("in_reply_to_uri")).
- Where("? IS NULL", pg.Ident("boost_of_id")).
- Order("status.id DESC")
-
- if maxID != "" {
- q = q.Where("status.id < ?", maxID)
- }
-
- if sinceID != "" {
- q = q.Where("status.id > ?", sinceID)
- }
-
- if minID != "" {
- q = q.Where("status.id > ?", minID)
- }
-
- if local {
- q = q.Where("status.local = ?", local)
- }
-
- if limit > 0 {
- q = q.Limit(limit)
- }
-
- err := q.Select()
- if err != nil {
- if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries
- }
- return nil, err
- }
-
- if len(statuses) == 0 {
- return nil, db.ErrNoEntries
- }
-
- return statuses, nil
-}
-
-// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
-// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
-func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
-
- faves := []*gtsmodel.StatusFave{}
-
- fq := t.conn.Model(&faves).
- Where("account_id = ?", accountID).
- Order("id DESC")
-
- if maxID != "" {
- fq = fq.Where("id < ?", maxID)
- }
-
- if minID != "" {
- fq = fq.Where("id > ?", minID)
- }
-
- if limit > 0 {
- fq = fq.Limit(limit)
- }
-
- err := fq.Select()
- if err != nil {
- if err == pg.ErrNoRows {
- return nil, "", "", db.ErrNoEntries
- }
- return nil, "", "", err
- }
-
- if len(faves) == 0 {
- return nil, "", "", db.ErrNoEntries
- }
-
- // map[statusID]faveID -- we need this to sort statuses by fave ID rather than their own ID
- statusesFavesMap := map[string]string{}
-
- in := []string{}
- for _, f := range faves {
- statusesFavesMap[f.StatusID] = f.ID
- in = append(in, f.StatusID)
- }
-
- statuses := []*gtsmodel.Status{}
- err = t.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select()
- if err != nil {
- if err == pg.ErrNoRows {
- return nil, "", "", db.ErrNoEntries
- }
- return nil, "", "", err
- }
-
- if len(statuses) == 0 {
- return nil, "", "", db.ErrNoEntries
- }
-
- // arrange statuses by fave ID
- sort.Slice(statuses, func(i int, j int) bool {
- statusI := statuses[i]
- statusJ := statuses[j]
- return statusesFavesMap[statusI.ID] < statusesFavesMap[statusJ.ID]
- })
-
- nextMaxID := faves[len(faves)-1].ID
- prevMinID := faves[0].ID
- return statuses, nextMaxID, prevMinID, nil
-}
diff --git a/internal/db/pg/util.go b/internal/db/pg/util.go
deleted file mode 100644
index 17c09b720..000000000
--- a/internal/db/pg/util.go
+++ /dev/null
@@ -1,25 +0,0 @@
-package pg
-
-import (
- "strings"
-
- "github.com/go-pg/pg/v10"
- "github.com/superseriousbusiness/gotosocial/internal/db"
-)
-
-// processErrorResponse parses the given error and returns an appropriate DBError.
-func processErrorResponse(err error) db.Error {
- switch err {
- case nil:
- return nil
- case pg.ErrNoRows:
- return db.ErrNoEntries
- case pg.ErrMultiRows:
- return db.ErrMultipleEntries
- default:
- if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
- return db.ErrAlreadyExists
- }
- return err
- }
-}
diff --git a/internal/db/relationship.go b/internal/db/relationship.go
index 85f64d72b..804526425 100644
--- a/internal/db/relationship.go
+++ b/internal/db/relationship.go
@@ -18,54 +18,58 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Relationship contains functions for getting or modifying the relationship between two accounts.
type Relationship interface {
// IsBlocked checks whether account 1 has a block in place against block2.
// If eitherDirection is true, then the function returns true if account1 blocks account2, OR if account2 blocks account1.
- IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, Error)
+ IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, Error)
// GetBlock returns the block from account1 targeting account2, if it exists, or an error if it doesn't.
//
// Because this is slower than Blocked, only use it if you need the actual Block struct for some reason,
// not if you're just checking for the existence of a block.
- GetBlock(account1 string, account2 string) (*gtsmodel.Block, Error)
+ GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, Error)
// GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
- GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error)
+ GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error)
// IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
- IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
+ IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
// IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
- IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
+ IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
// IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
- IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error)
+ IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error)
// AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table.
// In other words, it should create the follow, and delete the existing follow request.
//
// It will return the newly created follow for further processing.
- AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error)
+ AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error)
// GetAccountFollowRequests returns all follow requests targeting the given account.
- GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, Error)
+ GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, Error)
// GetAccountFollows returns a slice of follows owned by the given accountID.
- GetAccountFollows(accountID string) ([]*gtsmodel.Follow, Error)
+ GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, Error)
// CountAccountFollows returns the amount of accounts that the given accountID is following.
//
// If localOnly is set to true, then only follows from *this instance* will be returned.
- CountAccountFollows(accountID string, localOnly bool) (int, Error)
+ CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, Error)
// GetAccountFollowedBy fetches follows that target given accountID.
//
// If localOnly is set to true, then only follows from *this instance* will be returned.
- GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, Error)
+ GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, Error)
// CountAccountFollowedBy returns the amounts that the given ID is followed by.
- CountAccountFollowedBy(accountID string, localOnly bool) (int, Error)
+ CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, Error)
}
diff --git a/internal/db/session.go b/internal/db/session.go
new file mode 100644
index 000000000..ae13dccce
--- /dev/null
+++ b/internal/db/session.go
@@ -0,0 +1,31 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package db
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+// Session handles getting/creation of router sessions.
+type Session interface {
+ GetSession(ctx context.Context) (*gtsmodel.RouterSession, Error)
+ CreateSession(ctx context.Context) (*gtsmodel.RouterSession, Error)
+}
diff --git a/internal/db/status.go b/internal/db/status.go
index 9d206c198..7430433c4 100644
--- a/internal/db/status.go
+++ b/internal/db/status.go
@@ -18,58 +18,62 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses.
type Status interface {
// GetStatusByID returns one status from the database, with all rel fields populated (if possible).
- GetStatusByID(id string) (*gtsmodel.Status, Error)
+ GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, Error)
// GetStatusByURI returns one status from the database, with all rel fields populated (if possible).
- GetStatusByURI(uri string) (*gtsmodel.Status, Error)
+ GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, Error)
// GetStatusByURL returns one status from the database, with all rel fields populated (if possible).
- GetStatusByURL(uri string) (*gtsmodel.Status, Error)
+ GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error)
// PutStatus stores one status in the database.
- PutStatus(status *gtsmodel.Status) Error
+ PutStatus(ctx context.Context, status *gtsmodel.Status) Error
// CountStatusReplies returns the amount of replies recorded for a status, or an error if something goes wrong
- CountStatusReplies(status *gtsmodel.Status) (int, Error)
+ CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, Error)
// CountStatusReblogs returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong
- CountStatusReblogs(status *gtsmodel.Status) (int, Error)
+ CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, Error)
// CountStatusFaves returns the amount of faves/likes recorded for a status, or an error if something goes wrong
- CountStatusFaves(status *gtsmodel.Status) (int, Error)
+ CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, Error)
// GetStatusParents gets the parent statuses of a given status.
//
// If onlyDirect is true, only the immediate parent will be returned.
- GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error)
+ GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error)
// GetStatusChildren gets the child statuses of a given status.
//
// If onlyDirect is true, only the immediate children will be returned.
- GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error)
+ GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error)
// IsStatusFavedBy checks if a given status has been faved by a given account ID
- IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, Error)
+ IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID
- IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, Error)
+ IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusMutedBy checks if a given status has been muted by a given account ID
- IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, Error)
+ IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID
- IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, Error)
+ IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
// GetStatusFaves returns a slice of faves/likes of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
- GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, Error)
+ GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, Error)
// GetStatusReblogs returns a slice of statuses that are a boost/reblog of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
- GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, Error)
+ GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, Error)
}
diff --git a/internal/db/timeline.go b/internal/db/timeline.go
index 74aa5c781..83fb3a959 100644
--- a/internal/db/timeline.go
+++ b/internal/db/timeline.go
@@ -18,20 +18,24 @@
package db
-import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
// Timeline contains functionality for retrieving home/public/faved etc timelines for an account.
type Timeline interface {
// GetHomeTimeline returns a slice of statuses from accounts that are followed by the given account id.
//
// Statuses should be returned in descending order of when they were created (newest first).
- GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
+ GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
// GetPublicTimeline fetches the account's PUBLIC timeline -- ie., posts and replies that are public.
// It will use the given filters and try to return as many statuses as possible up to the limit.
//
// Statuses should be returned in descending order of when they were created (newest first).
- GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
+ GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
// GetFavedTimeline fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved.
// It will use the given filters and try to return as many statuses as possible up to the limit.
@@ -40,5 +44,5 @@ type Timeline interface {
// In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created.
//
// Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers.
- GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error)
+ GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error)
}
--
cgit v1.2.3