diff options
author | 2021-08-25 15:34:33 +0200 | |
---|---|---|
committer | 2021-08-25 15:34:33 +0200 | |
commit | 2dc9fc1626507bb54417fc4a1920b847cafb27a2 (patch) | |
tree | 4ddeac479b923db38090aac8bd9209f3646851c1 /internal/db/bundb | |
parent | Manually approves followers (#146) (diff) | |
download | gotosocial-2dc9fc1626507bb54417fc4a1920b847cafb27a2.tar.xz |
Pg to bun (#148)
* start moving to bun
* changing more stuff
* more
* and yet more
* tests passing
* seems stable now
* more big changes
* small fix
* little fixes
Diffstat (limited to 'internal/db/bundb')
-rw-r--r-- | internal/db/bundb/account.go | 291 | ||||
-rw-r--r-- | internal/db/bundb/account_test.go | 86 | ||||
-rw-r--r-- | internal/db/bundb/admin.go | 272 | ||||
-rw-r--r-- | internal/db/bundb/basic.go | 179 | ||||
-rw-r--r-- | internal/db/bundb/basic_test.go | 68 | ||||
-rw-r--r-- | internal/db/bundb/bundb.go | 410 | ||||
-rw-r--r-- | internal/db/bundb/bundb_test.go | 47 | ||||
-rw-r--r-- | internal/db/bundb/domain.go | 81 | ||||
-rw-r--r-- | internal/db/bundb/instance.go | 118 | ||||
-rw-r--r-- | internal/db/bundb/media.go | 53 | ||||
-rw-r--r-- | internal/db/bundb/mention.go | 108 | ||||
-rw-r--r-- | internal/db/bundb/notification.go | 136 | ||||
-rw-r--r-- | internal/db/bundb/relationship.go | 328 | ||||
-rw-r--r-- | internal/db/bundb/session.go | 85 | ||||
-rw-r--r-- | internal/db/bundb/status.go | 375 | ||||
-rw-r--r-- | internal/db/bundb/status_test.go | 136 | ||||
-rw-r--r-- | internal/db/bundb/timeline.go | 199 | ||||
-rw-r--r-- | internal/db/bundb/util.go | 78 |
18 files changed, 3050 insertions, 0 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go new file mode 100644 index 000000000..7ebb79a15 --- /dev/null +++ b/internal/db/bundb/account.go @@ -0,0 +1,291 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +type accountDB struct { + config *config.Config + conn *bun.DB + log *logrus.Logger +} + +func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { + return a.conn. + NewSelect(). + Model(account). + Relation("AvatarMediaAttachment"). + Relation("HeaderMediaAttachment") +} + +func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { + account := new(gtsmodel.Account) + + q := a.newAccountQ(account). + Where("account.id = ?", id) + + err := processErrorResponse(q.Scan(ctx)) + + return account, err +} + +func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { + account := new(gtsmodel.Account) + + q := a.newAccountQ(account). + Where("account.uri = ?", uri) + + err := processErrorResponse(q.Scan(ctx)) + + return account, err +} + +func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { + account := new(gtsmodel.Account) + + q := a.newAccountQ(account). + Where("account.url = ?", uri) + + err := processErrorResponse(q.Scan(ctx)) + + return account, err +} + +func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { + if strings.TrimSpace(account.ID) == "" { + return nil, errors.New("account had no ID") + } + + account.UpdatedAt = time.Now() + + q := a.conn. + NewUpdate(). + Model(account). + WherePK() + + _, err := q.Exec(ctx) + + err = processErrorResponse(err) + + return account, err +} + +func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { + account := new(gtsmodel.Account) + + q := a.newAccountQ(account) + + if domain == "" { + q = q. + Where("account.username = ?", domain). + Where("account.domain = ?", domain) + } else { + q = q. + Where("account.username = ?", domain). + Where("? IS NULL", bun.Ident("domain")) + } + + err := processErrorResponse(q.Scan(ctx)) + + return account, err +} + +func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, db.Error) { + status := new(gtsmodel.Status) + + q := a.conn. + NewSelect(). + Model(status). + Order("id DESC"). + Limit(1). + Where("account_id = ?", accountID). + Column("created_at") + + err := processErrorResponse(q.Scan(ctx)) + + return status.CreatedAt, err +} + +func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error { + if mediaAttachment.Avatar && mediaAttachment.Header { + return errors.New("one media attachment cannot be both header and avatar") + } + + var headerOrAVI string + if mediaAttachment.Avatar { + headerOrAVI = "avatar" + } else if mediaAttachment.Header { + headerOrAVI = "header" + } else { + return errors.New("given media attachment was neither a header nor an avatar") + } + + // TODO: there are probably more side effects here that need to be handled + if _, err := a.conn. + NewInsert(). + Model(mediaAttachment). + Exec(ctx); err != nil { + return err + } + + if _, err := a.conn. + NewUpdate(). + Model(>smodel.Account{}). + Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID). + Where("id = ?", accountID). + Exec(ctx); err != nil { + return err + } + return nil +} + +func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, db.Error) { + account := new(gtsmodel.Account) + + q := a.newAccountQ(account). + Where("username = ?", username). + Where("? IS NULL", bun.Ident("domain")) + + err := processErrorResponse(q.Scan(ctx)) + + return account, err +} + +func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) { + faves := new([]*gtsmodel.StatusFave) + + if err := a.conn. + NewSelect(). + Model(faves). + Where("account_id = ?", accountID). + Scan(ctx); err != nil { + return nil, err + } + return *faves, nil +} + +func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) { + return a.conn. + NewSelect(). + Model(>smodel.Status{}). + Where("account_id = ?", accountID). + Count(ctx) +} + +func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) { + a.log.Debugf("getting statuses for account %s", accountID) + statuses := []*gtsmodel.Status{} + + q := a.conn. + NewSelect(). + Model(&statuses). + Order("id DESC") + + if accountID != "" { + q = q.Where("account_id = ?", accountID) + } + + if limit != 0 { + q = q.Limit(limit) + } + + if excludeReplies { + q = q.Where("? IS NULL", bun.Ident("in_reply_to_id")) + } + + if pinnedOnly { + q = q.Where("pinned = ?", true) + } + + if maxID != "" { + q = q.Where("id < ?", maxID) + } + + if mediaOnly { + q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery { + return q. + WhereOr("? IS NOT NULL", bun.Ident("attachments")). + WhereOr("attachments != '{}'") + }) + } + + if err := q.Scan(ctx); err != nil { + return nil, err + } + + if len(statuses) == 0 { + return nil, db.ErrNoEntries + } + + a.log.Debugf("returning statuses for account %s", accountID) + return statuses, nil +} + +func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) { + blocks := []*gtsmodel.Block{} + + fq := a.conn. + NewSelect(). + Model(&blocks). + Where("block.account_id = ?", accountID). + Relation("TargetAccount"). + Order("block.id DESC") + + if maxID != "" { + fq = fq.Where("block.id < ?", maxID) + } + + if sinceID != "" { + fq = fq.Where("block.id > ?", sinceID) + } + + if limit > 0 { + fq = fq.Limit(limit) + } + + err := fq.Scan(ctx) + if err != nil { + return nil, "", "", err + } + + if len(blocks) == 0 { + return nil, "", "", db.ErrNoEntries + } + + accounts := []*gtsmodel.Account{} + for _, b := range blocks { + accounts = append(accounts, b.TargetAccount) + } + + nextMaxID := blocks[len(blocks)-1].ID + prevMinID := blocks[0].ID + return accounts, nextMaxID, prevMinID, nil +} diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go new file mode 100644 index 000000000..7174b781d --- /dev/null +++ b/internal/db/bundb/account_test.go @@ -0,0 +1,86 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type AccountTestSuite struct { + BunDBStandardTestSuite +} + +func (suite *AccountTestSuite) SetupSuite() { + suite.testTokens = testrig.NewTestTokens() + suite.testClients = testrig.NewTestClients() + suite.testApplications = testrig.NewTestApplications() + suite.testUsers = testrig.NewTestUsers() + suite.testAccounts = testrig.NewTestAccounts() + suite.testAttachments = testrig.NewTestAttachments() + suite.testStatuses = testrig.NewTestStatuses() + suite.testTags = testrig.NewTestTags() + suite.testMentions = testrig.NewTestMentions() +} + +func (suite *AccountTestSuite) SetupTest() { + suite.config = testrig.NewTestConfig() + suite.db = testrig.NewTestDB() + suite.log = testrig.NewTestLog() + + testrig.StandardDBSetup(suite.db, suite.testAccounts) +} + +func (suite *AccountTestSuite) TearDownTest() { + testrig.StandardDBTeardown(suite.db) +} + +func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() { + account, err := suite.db.GetAccountByID(context.Background(), suite.testAccounts["local_account_1"].ID) + if err != nil { + suite.FailNow(err.Error()) + } + suite.NotNil(account) + suite.NotNil(account.AvatarMediaAttachment) + suite.NotEmpty(account.AvatarMediaAttachment.URL) + suite.NotNil(account.HeaderMediaAttachment) + suite.NotEmpty(account.HeaderMediaAttachment.URL) +} + +func (suite *AccountTestSuite) TestUpdateAccount() { + testAccount := suite.testAccounts["local_account_1"] + + testAccount.DisplayName = "new display name!" + + _, err := suite.db.UpdateAccount(context.Background(), testAccount) + suite.NoError(err) + + updated, err := suite.db.GetAccountByID(context.Background(), testAccount.ID) + suite.NoError(err) + suite.Equal("new display name!", updated.DisplayName) + suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second) +} + +func TestAccountTestSuite(t *testing.T) { + suite.Run(t, new(AccountTestSuite)) +} diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go new file mode 100644 index 000000000..67a1e8a0d --- /dev/null +++ b/internal/db/bundb/admin.go @@ -0,0 +1,272 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "database/sql" + "fmt" + "net" + "net/mail" + "strings" + "time" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/uptrace/bun" + "golang.org/x/crypto/bcrypt" +) + +type adminDB struct { + config *config.Config + conn *bun.DB + log *logrus.Logger +} + +func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) { + q := a.conn. + NewSelect(). + Model(>smodel.Account{}). + Where("username = ?", username). + Where("domain = ?", nil) + + return notExists(ctx, q) +} + +func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.Error) { + // parse the domain from the email + m, err := mail.ParseAddress(email) + if err != nil { + return false, fmt.Errorf("error parsing email address %s: %s", email, err) + } + domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @ + + // check if the email domain is blocked + if err := a.conn. + NewSelect(). + Model(>smodel.EmailDomainBlock{}). + Where("domain = ?", domain). + Scan(ctx); err == nil { + // fail because we found something + return false, fmt.Errorf("email domain %s is blocked", domain) + } else if err != sql.ErrNoRows { + return false, processErrorResponse(err) + } + + // check if this email is associated with a user already + q := a.conn. + NewSelect(). + Model(>smodel.User{}). + Where("email = ?", email). + WhereOr("unconfirmed_email = ?", email) + + return notExists(ctx, q) +} + +func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + a.log.Errorf("error creating new rsa key: %s", err) + return nil, err + } + + // if something went wrong while creating a user, we might already have an account, so check here first... + acct := >smodel.Account{} + err = a.conn.NewSelect(). + Model(acct). + Where("username = ?", username). + Where("? IS NULL", bun.Ident("domain")). + Scan(ctx) + if err != nil { + // we just don't have an account yet create one + newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host) + newAccountID, err := id.NewRandomULID() + if err != nil { + return nil, err + } + + acct = >smodel.Account{ + ID: newAccountID, + Username: username, + DisplayName: username, + Reason: reason, + URL: newAccountURIs.UserURL, + PrivateKey: key, + PublicKey: &key.PublicKey, + PublicKeyURI: newAccountURIs.PublicKeyURI, + ActorType: gtsmodel.ActivityStreamsPerson, + URI: newAccountURIs.UserURI, + InboxURI: newAccountURIs.InboxURI, + OutboxURI: newAccountURIs.OutboxURI, + FollowersURI: newAccountURIs.FollowersURI, + FollowingURI: newAccountURIs.FollowingURI, + FeaturedCollectionURI: newAccountURIs.CollectionURI, + } + if _, err = a.conn. + NewInsert(). + Model(acct). + Exec(ctx); err != nil { + return nil, err + } + } + + pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return nil, fmt.Errorf("error hashing password: %s", err) + } + + newUserID, err := id.NewRandomULID() + if err != nil { + return nil, err + } + + u := >smodel.User{ + ID: newUserID, + AccountID: acct.ID, + EncryptedPassword: string(pw), + SignUpIP: signUpIP.To4(), + Locale: locale, + UnconfirmedEmail: email, + CreatedByApplicationID: appID, + Approved: !requireApproval, // if we don't require moderator approval, just pre-approve the user + } + + if emailVerified { + u.ConfirmedAt = time.Now() + u.Email = email + } + + if admin { + u.Admin = true + u.Moderator = true + } + + if _, err = a.conn. + NewInsert(). + Model(u). + Exec(ctx); err != nil { + return nil, err + } + + return u, nil +} + +func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { + username := a.config.Host + + // check if instance account already exists + existsQ := a.conn. + NewSelect(). + Model(>smodel.Account{}). + Where("username = ?", username). + Where("? IS NULL", bun.Ident("domain")) + count, err := existsQ.Count(ctx) + if err != nil && count == 1 { + a.log.Infof("instance account %s already exists", username) + return nil + } else if err != sql.ErrNoRows { + return processErrorResponse(err) + } + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + a.log.Errorf("error creating new rsa key: %s", err) + return err + } + + aID, err := id.NewRandomULID() + if err != nil { + return err + } + + newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host) + acct := >smodel.Account{ + ID: aID, + Username: a.config.Host, + DisplayName: username, + URL: newAccountURIs.UserURL, + PrivateKey: key, + PublicKey: &key.PublicKey, + PublicKeyURI: newAccountURIs.PublicKeyURI, + ActorType: gtsmodel.ActivityStreamsPerson, + URI: newAccountURIs.UserURI, + InboxURI: newAccountURIs.InboxURI, + OutboxURI: newAccountURIs.OutboxURI, + FollowersURI: newAccountURIs.FollowersURI, + FollowingURI: newAccountURIs.FollowingURI, + FeaturedCollectionURI: newAccountURIs.CollectionURI, + } + + insertQ := a.conn. + NewInsert(). + Model(acct) + + if _, err := insertQ.Exec(ctx); err != nil { + return err + } + + a.log.Infof("instance account %s CREATED with id %s", username, acct.ID) + return nil +} + +func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error { + domain := a.config.Host + + // check if instance entry already exists + existsQ := a.conn. + NewSelect(). + Model(>smodel.Instance{}). + Where("domain = ?", domain) + + count, err := existsQ.Count(ctx) + if err != nil && count == 1 { + a.log.Infof("instance instance %s already exists", domain) + return nil + } else if err != sql.ErrNoRows { + return processErrorResponse(err) + } + + iID, err := id.NewRandomULID() + if err != nil { + return err + } + + i := >smodel.Instance{ + ID: iID, + Domain: domain, + Title: domain, + URI: fmt.Sprintf("%s://%s", a.config.Protocol, a.config.Host), + } + + insertQ := a.conn. + NewInsert(). + Model(i) + + if _, err := insertQ.Exec(ctx); err != nil { + return err + } + a.log.Infof("created instance instance %s with id %s", domain, i.ID) + return nil +} diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go new file mode 100644 index 000000000..983b6b810 --- /dev/null +++ b/internal/db/bundb/basic.go @@ -0,0 +1,179 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + "errors" + "strings" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/uptrace/bun" +) + +type basicDB struct { + config *config.Config + conn *bun.DB + log *logrus.Logger +} + +func (b *basicDB) Put(ctx context.Context, i interface{}) db.Error { + _, err := b.conn.NewInsert().Model(i).Exec(ctx) + if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") { + return db.ErrAlreadyExists + } + return err +} + +func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Error { + q := b.conn. + NewSelect(). + Model(i). + Where("id = ?", id) + + return processErrorResponse(q.Scan(ctx)) +} + +func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error { + if len(where) == 0 { + return errors.New("no queries provided") + } + + q := b.conn.NewSelect().Model(i) + for _, w := range where { + + if w.Value == nil { + q = q.Where("? IS NULL", bun.Ident(w.Key)) + } else { + if w.CaseInsensitive { + q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value) + } else { + q = q.Where("? = ?", bun.Safe(w.Key), w.Value) + } + } + } + + return processErrorResponse(q.Scan(ctx)) +} + +func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error { + q := b.conn. + NewSelect(). + Model(i) + + return processErrorResponse(q.Scan(ctx)) +} + +func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error { + q := b.conn. + NewDelete(). + Model(i). + Where("id = ?", id) + + _, err := q.Exec(ctx) + + return processErrorResponse(err) +} + +func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error { + if len(where) == 0 { + return errors.New("no queries provided") + } + + q := b.conn. + NewDelete(). + Model(i) + + for _, w := range where { + q = q.Where("? = ?", bun.Safe(w.Key), w.Value) + } + + _, err := q.Exec(ctx) + + return processErrorResponse(err) +} + +func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.Error { + q := b.conn. + NewUpdate(). + Model(i). + WherePK() + + _, err := q.Exec(ctx) + + return processErrorResponse(err) +} + +func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) db.Error { + q := b.conn.NewUpdate(). + Model(i). + Set("? = ?", bun.Safe(key), value). + WherePK() + + _, err := q.Exec(ctx) + + return processErrorResponse(err) +} + +func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error { + q := b.conn.NewUpdate().Model(i) + + for _, w := range where { + if w.Value == nil { + q = q.Where("? IS NULL", bun.Ident(w.Key)) + } else { + if w.CaseInsensitive { + q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value) + } else { + q = q.Where("? = ?", bun.Safe(w.Key), w.Value) + } + } + } + + q = q.Set("? = ?", bun.Safe(key), value) + + _, err := q.Exec(ctx) + + return processErrorResponse(err) +} + +func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error { + _, err := b.conn.NewCreateTable().Model(i).IfNotExists().Exec(ctx) + return err +} + +func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error { + _, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx) + return processErrorResponse(err) +} + +func (b *basicDB) IsHealthy(ctx context.Context) db.Error { + return b.conn.Ping() +} + +func (b *basicDB) Stop(ctx context.Context) db.Error { + b.log.Info("closing db connection") + if err := b.conn.Close(); err != nil { + // only cancel if there's a problem closing the db + return err + } + return nil +} diff --git a/internal/db/bundb/basic_test.go b/internal/db/bundb/basic_test.go new file mode 100644 index 000000000..9189618c9 --- /dev/null +++ b/internal/db/bundb/basic_test.go @@ -0,0 +1,68 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type BasicTestSuite struct { + BunDBStandardTestSuite +} + +func (suite *BasicTestSuite) SetupSuite() { + suite.testTokens = testrig.NewTestTokens() + suite.testClients = testrig.NewTestClients() + suite.testApplications = testrig.NewTestApplications() + suite.testUsers = testrig.NewTestUsers() + suite.testAccounts = testrig.NewTestAccounts() + suite.testAttachments = testrig.NewTestAttachments() + suite.testStatuses = testrig.NewTestStatuses() + suite.testTags = testrig.NewTestTags() + suite.testMentions = testrig.NewTestMentions() +} + +func (suite *BasicTestSuite) SetupTest() { + suite.config = testrig.NewTestConfig() + suite.db = testrig.NewTestDB() + suite.log = testrig.NewTestLog() + + testrig.StandardDBSetup(suite.db, suite.testAccounts) +} + +func (suite *BasicTestSuite) TearDownTest() { + testrig.StandardDBTeardown(suite.db) +} + +func (suite *BasicTestSuite) TestGetAccountByID() { + testAccount := suite.testAccounts["local_account_1"] + + a := >smodel.Account{} + err := suite.db.GetByID(context.Background(), testAccount.ID, a) + suite.NoError(err) +} + +func TestBasicTestSuite(t *testing.T) { + suite.Run(t, new(BasicTestSuite)) +} diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go new file mode 100644 index 000000000..49ed09cbd --- /dev/null +++ b/internal/db/bundb/bundb.go @@ -0,0 +1,410 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + "crypto/tls" + "crypto/x509" + "database/sql" + "encoding/pem" + "errors" + "fmt" + "os" + "strings" + "time" + + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/stdlib" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/pgdialect" +) + +const ( + dbTypePostgres = "postgres" + dbTypeSqlite = "sqlite" +) + +var registerTables []interface{} = []interface{}{ + >smodel.StatusToEmoji{}, + >smodel.StatusToTag{}, +} + +// bunDBService satisfies the DB interface +type bunDBService struct { + db.Account + db.Admin + db.Basic + db.Domain + db.Instance + db.Media + db.Mention + db.Notification + db.Relationship + db.Session + db.Status + db.Timeline + config *config.Config + conn *bun.DB + log *logrus.Logger +} + +// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface. +// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection. +func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) { + var sqldb *sql.DB + var conn *bun.DB + + // depending on the database type we're trying to create, we need to use a different driver... + switch strings.ToLower(c.DBConfig.Type) { + case dbTypePostgres: + // POSTGRES + opts, err := deriveBunDBPGOptions(c) + if err != nil { + return nil, fmt.Errorf("could not create bundb postgres options: %s", err) + } + sqldb = stdlib.OpenDB(*opts) + conn = bun.NewDB(sqldb, pgdialect.New()) + case dbTypeSqlite: + // SQLITE + // TODO: https://bun.uptrace.dev/guide/drivers.html#sqlite + default: + return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type)) + } + + // actually *begin* the connection so that we can tell if the db is there and listening + if err := conn.Ping(); err != nil { + return nil, fmt.Errorf("db connection error: %s", err) + } + log.Info("connected to database") + + for _, t := range registerTables { + // https://bun.uptrace.dev/orm/many-to-many-relation/ + conn.RegisterModel(t) + } + + ps := &bunDBService{ + Account: &accountDB{ + config: c, + conn: conn, + log: log, + }, + Admin: &adminDB{ + config: c, + conn: conn, + log: log, + }, + Basic: &basicDB{ + config: c, + conn: conn, + log: log, + }, + Domain: &domainDB{ + config: c, + conn: conn, + log: log, + }, + Instance: &instanceDB{ + config: c, + conn: conn, + log: log, + }, + Media: &mediaDB{ + config: c, + conn: conn, + log: log, + }, + Mention: &mentionDB{ + config: c, + conn: conn, + log: log, + }, + Notification: ¬ificationDB{ + config: c, + conn: conn, + log: log, + }, + Relationship: &relationshipDB{ + config: c, + conn: conn, + log: log, + }, + Session: &sessionDB{ + config: c, + conn: conn, + log: log, + }, + Status: &statusDB{ + config: c, + conn: conn, + log: log, + }, + Timeline: &timelineDB{ + config: c, + conn: conn, + log: log, + }, + config: c, + conn: conn, + log: log, + } + + // we can confidently return this useable service now + return ps, nil +} + +/* + HANDY STUFF +*/ + +// deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options +// with sensible defaults, or an error if it's not satisfied by the provided config. +func deriveBunDBPGOptions(c *config.Config) (*pgx.ConnConfig, error) { + if strings.ToUpper(c.DBConfig.Type) != db.DBTypePostgres { + return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, c.DBConfig.Type) + } + + // validate port + if c.DBConfig.Port == 0 { + return nil, errors.New("no port set") + } + + // validate address + if c.DBConfig.Address == "" { + return nil, errors.New("no address set") + } + + // validate username + if c.DBConfig.User == "" { + return nil, errors.New("no user set") + } + + // validate that there's a password + if c.DBConfig.Password == "" { + return nil, errors.New("no password set") + } + + // validate database + if c.DBConfig.Database == "" { + return nil, errors.New("no database set") + } + + var tlsConfig *tls.Config + switch c.DBConfig.TLSMode { + case config.DBTLSModeDisable, config.DBTLSModeUnset: + break // nothing to do + case config.DBTLSModeEnable: + tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + case config.DBTLSModeRequire: + tlsConfig = &tls.Config{ + InsecureSkipVerify: false, + ServerName: c.DBConfig.Address, + } + } + + if tlsConfig != nil && c.DBConfig.TLSCACert != "" { + // load the system cert pool first -- we'll append the given CA cert to this + certPool, err := x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("error fetching system CA cert pool: %s", err) + } + + // open the file itself and make sure there's something in it + caCertBytes, err := os.ReadFile(c.DBConfig.TLSCACert) + if err != nil { + return nil, fmt.Errorf("error opening CA certificate at %s: %s", c.DBConfig.TLSCACert, err) + } + if len(caCertBytes) == 0 { + return nil, fmt.Errorf("ca cert at %s was empty", c.DBConfig.TLSCACert) + } + + // make sure we have a PEM block + caPem, _ := pem.Decode(caCertBytes) + if caPem == nil { + return nil, fmt.Errorf("could not parse cert at %s into PEM", c.DBConfig.TLSCACert) + } + + // parse the PEM block into the certificate + caCert, err := x509.ParseCertificate(caPem.Bytes) + if err != nil { + return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", c.DBConfig.TLSCACert, err) + } + + // we're happy, add it to the existing pool and then use this pool in our tls config + certPool.AddCert(caCert) + tlsConfig.RootCAs = certPool + } + + cfg, _ := pgx.ParseConfig("") + cfg.Host = c.DBConfig.Address + cfg.Port = uint16(c.DBConfig.Port) + cfg.User = c.DBConfig.User + cfg.Password = c.DBConfig.Password + cfg.TLSConfig = tlsConfig + cfg.Database = c.DBConfig.Database + cfg.PreferSimpleProtocol = true + + return cfg, nil +} + +/* + CONVERSION FUNCTIONS +*/ + +// TODO: move these to the type converter, it's bananas that they're here and not there + +func (ps *bunDBService) MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) { + ogAccount := >smodel.Account{} + if err := ps.conn.NewSelect().Model(ogAccount).Where("id = ?", originAccountID).Scan(ctx); err != nil { + return nil, err + } + + menchies := []*gtsmodel.Mention{} + for _, a := range targetAccounts { + // A mentioned account looks like "@test@example.org" or just "@test" for a local account + // -- we can guarantee this from the regex that targetAccounts should have been derived from. + // But we still need to do a bit of fiddling to get what we need here -- the username and domain (if given). + + // 1. trim off the first @ + t := strings.TrimPrefix(a, "@") + + // 2. split the username and domain + s := strings.Split(t, "@") + + // 3. if it's length 1 it's a local account, length 2 means remote, anything else means something is wrong + var local bool + switch len(s) { + case 1: + local = true + case 2: + local = false + default: + return nil, fmt.Errorf("mentioned account format '%s' was not valid", a) + } + + var username, domain string + username = s[0] + if !local { + domain = s[1] + } + + // 4. check we now have a proper username and domain + if username == "" || (!local && domain == "") { + return nil, fmt.Errorf("username or domain for '%s' was nil", a) + } + + // okay we're good now, we can start pulling accounts out of the database + mentionedAccount := >smodel.Account{} + var err error + + // match username + account, case insensitive + if local { + // local user -- should have a null domain + err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("? IS NULL", bun.Ident("domain")).Scan(ctx) + } else { + // remote user -- should have domain defined + err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("LOWER(?) = LOWER(?)", bun.Ident("domain"), domain).Scan(ctx) + } + + if err != nil { + if err == sql.ErrNoRows { + // no result found for this username/domain so just don't include it as a mencho and carry on about our business + ps.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain) + continue + } + // a serious error has happened so bail + return nil, fmt.Errorf("error getting account with username '%s' and domain '%s': %s", username, domain, err) + } + + // id, createdAt and updatedAt will be populated by the db, so we have everything we need! + menchies = append(menchies, >smodel.Mention{ + StatusID: statusID, + OriginAccountID: ogAccount.ID, + OriginAccountURI: ogAccount.URI, + TargetAccountID: mentionedAccount.ID, + NameString: a, + TargetAccountURI: mentionedAccount.URI, + TargetAccountURL: mentionedAccount.URL, + OriginAccount: mentionedAccount, + }) + } + return menchies, nil +} + +func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) { + newTags := []*gtsmodel.Tag{} + for _, t := range tags { + tag := >smodel.Tag{} + // we can use selectorinsert here to create the new tag if it doesn't exist already + // inserted will be true if this is a new tag we just created + if err := ps.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil { + if err == sql.ErrNoRows { + // tag doesn't exist yet so populate it + newID, err := id.NewRandomULID() + if err != nil { + return nil, err + } + tag.ID = newID + tag.URL = fmt.Sprintf("%s://%s/tags/%s", ps.config.Protocol, ps.config.Host, t) + tag.Name = t + tag.FirstSeenFromAccountID = originAccountID + tag.CreatedAt = time.Now() + tag.UpdatedAt = time.Now() + tag.Useable = true + tag.Listable = true + } else { + return nil, fmt.Errorf("error getting tag with name %s: %s", t, err) + } + } + + // bail already if the tag isn't useable + if !tag.Useable { + continue + } + tag.LastStatusAt = time.Now() + newTags = append(newTags, tag) + } + return newTags, nil +} + +func (ps *bunDBService) EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) { + newEmojis := []*gtsmodel.Emoji{} + for _, e := range emojis { + emoji := >smodel.Emoji{} + err := ps.conn.NewSelect().Model(emoji).Where("shortcode = ?", e).Where("visible_in_picker = true").Where("disabled = false").Scan(ctx) + if err != nil { + if err == sql.ErrNoRows { + // no result found for this username/domain so just don't include it as an emoji and carry on about our business + ps.log.Debugf("no emoji found with shortcode %s, skipping it", e) + continue + } + // a serious error has happened so bail + return nil, fmt.Errorf("error getting emoji with shortcode %s: %s", e, err) + } + newEmojis = append(newEmojis, emoji) + } + return newEmojis, nil +} diff --git a/internal/db/bundb/bundb_test.go b/internal/db/bundb/bundb_test.go new file mode 100644 index 000000000..b789375af --- /dev/null +++ b/internal/db/bundb/bundb_test.go @@ -0,0 +1,47 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb_test + +import ( + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +type BunDBStandardTestSuite struct { + // standard suite interfaces + suite.Suite + config *config.Config + db db.DB + log *logrus.Logger + + // standard suite models + testTokens map[string]*oauth.Token + testClients map[string]*oauth.Client + testApplications map[string]*gtsmodel.Application + testUsers map[string]*gtsmodel.User + testAccounts map[string]*gtsmodel.Account + testAttachments map[string]*gtsmodel.MediaAttachment + testStatuses map[string]*gtsmodel.Status + testTags map[string]*gtsmodel.Tag + testMentions map[string]*gtsmodel.Mention +} diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go new file mode 100644 index 000000000..6aa2b8ffe --- /dev/null +++ b/internal/db/bundb/domain.go @@ -0,0 +1,81 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + "net/url" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/uptrace/bun" +) + +type domainDB struct { + config *config.Config + conn *bun.DB + log *logrus.Logger +} + +func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) { + if domain == "" { + return false, nil + } + + q := d.conn. + NewSelect(). + Model(>smodel.DomainBlock{}). + Where("LOWER(domain) = LOWER(?)", domain). + Limit(1) + + return exists(ctx, q) +} + +func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) { + // filter out any doubles + uniqueDomains := util.UniqueStrings(domains) + + for _, domain := range uniqueDomains { + if blocked, err := d.IsDomainBlocked(ctx, domain); err != nil { + return false, err + } else if blocked { + return blocked, nil + } + } + + // no blocks found + return false, nil +} + +func (d *domainDB) IsURIBlocked(ctx context.Context, uri *url.URL) (bool, db.Error) { + domain := uri.Hostname() + return d.IsDomainBlocked(ctx, domain) +} + +func (d *domainDB) AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, db.Error) { + domains := []string{} + for _, uri := range uris { + domains = append(domains, uri.Hostname()) + } + + return d.AreDomainsBlocked(ctx, domains) +} diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go new file mode 100644 index 000000000..f9364346e --- /dev/null +++ b/internal/db/bundb/instance.go @@ -0,0 +1,118 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +type instanceDB struct { + config *config.Config + conn *bun.DB + log *logrus.Logger +} + +func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) { + q := i.conn. + NewSelect(). + Model(&[]*gtsmodel.Account{}) + + if domain == i.config.Host { + // if the domain is *this* domain, just count where the domain field is null + q = q.Where("? IS NULL", bun.Ident("domain")) + } else { + q = q.Where("domain = ?", domain) + } + + // don't count the instance account or suspended users + q = q. + Where("username != ?", domain). + Where("? IS NULL", bun.Ident("suspended_at")) + + count, err := q.Count(ctx) + + return count, processErrorResponse(err) +} + +func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) { + q := i.conn. + NewSelect(). + Model(&[]*gtsmodel.Status{}) + + if domain == i.config.Host { + // if the domain is *this* domain, just count where local is true + q = q.Where("local = ?", true) + } else { + // join on the domain of the account + q = q.Join("JOIN accounts AS account ON account.id = status.account_id"). + Where("account.domain = ?", domain) + } + + count, err := q.Count(ctx) + + return count, processErrorResponse(err) +} + +func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) { + q := i.conn. + NewSelect(). + Model(&[]*gtsmodel.Instance{}) + + if domain == i.config.Host { + // if the domain is *this* domain, just count other instances it knows about + // exclude domains that are blocked + q = q.Where("domain != ?", domain).Where("? IS NULL", bun.Ident("suspended_at")) + } else { + // TODO: implement federated domain counting properly for remote domains + return 0, nil + } + + count, err := q.Count(ctx) + + return count, processErrorResponse(err) +} + +func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) { + i.log.Debug("GetAccountsForInstance") + + accounts := []*gtsmodel.Account{} + + q := i.conn.NewSelect(). + Model(&accounts). + Where("domain = ?", domain). + Order("id DESC") + + if maxID != "" { + q = q.Where("id < ?", maxID) + } + + if limit > 0 { + q = q.Limit(limit) + } + + err := processErrorResponse(q.Scan(ctx)) + + return accounts, err +} diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go new file mode 100644 index 000000000..04e55ca62 --- /dev/null +++ b/internal/db/bundb/media.go @@ -0,0 +1,53 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +type mediaDB struct { + config *config.Config + conn *bun.DB + log *logrus.Logger +} + +func (m *mediaDB) newMediaQ(i interface{}) *bun.SelectQuery { + return m.conn. + NewSelect(). + Model(i). + Relation("Account") +} + +func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, db.Error) { + attachment := >smodel.MediaAttachment{} + + q := m.newMediaQ(attachment). + Where("media_attachment.id = ?", id) + + err := processErrorResponse(q.Scan(ctx)) + + return attachment, err +} diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go new file mode 100644 index 000000000..a444f9b5f --- /dev/null +++ b/internal/db/bundb/mention.go @@ -0,0 +1,108 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/cache" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +type mentionDB struct { + config *config.Config + conn *bun.DB + log *logrus.Logger + cache cache.Cache +} + +func (m *mentionDB) cacheMention(id string, mention *gtsmodel.Mention) { + if m.cache == nil { + m.cache = cache.New() + } + + if err := m.cache.Store(id, mention); err != nil { + m.log.Panicf("mentionDB: error storing in cache: %s", err) + } +} + +func (m *mentionDB) mentionCached(id string) (*gtsmodel.Mention, bool) { + if m.cache == nil { + m.cache = cache.New() + return nil, false + } + + mI, err := m.cache.Fetch(id) + if err != nil || mI == nil { + return nil, false + } + + mention, ok := mI.(*gtsmodel.Mention) + if !ok { + m.log.Panicf("mentionDB: cached interface with key %s was not a mention", id) + } + + return mention, true +} + +func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery { + return m.conn. + NewSelect(). + Model(i). + Relation("Status"). + Relation("OriginAccount"). + Relation("TargetAccount") +} + +func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) { + if mention, cached := m.mentionCached(id); cached { + return mention, nil + } + + mention := >smodel.Mention{} + + q := m.newMentionQ(mention). + Where("mention.id = ?", id) + + err := processErrorResponse(q.Scan(ctx)) + + if err == nil && mention != nil { + m.cacheMention(id, mention) + } + + return mention, err +} + +func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) { + mentions := []*gtsmodel.Mention{} + + for _, i := range ids { + mention, err := m.GetMention(ctx, i) + if err != nil { + return nil, processErrorResponse(err) + } + mentions = append(mentions, mention) + } + + return mentions, nil +} diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go new file mode 100644 index 000000000..1c30837ec --- /dev/null +++ b/internal/db/bundb/notification.go @@ -0,0 +1,136 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/cache" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +type notificationDB struct { + config *config.Config + conn *bun.DB + log *logrus.Logger + cache cache.Cache +} + +func (n *notificationDB) cacheNotification(id string, notification *gtsmodel.Notification) { + if n.cache == nil { + n.cache = cache.New() + } + + if err := n.cache.Store(id, notification); err != nil { + n.log.Panicf("notificationDB: error storing in cache: %s", err) + } +} + +func (n *notificationDB) notificationCached(id string) (*gtsmodel.Notification, bool) { + if n.cache == nil { + n.cache = cache.New() + return nil, false + } + + nI, err := n.cache.Fetch(id) + if err != nil || nI == nil { + return nil, false + } + + notification, ok := nI.(*gtsmodel.Notification) + if !ok { + n.log.Panicf("notificationDB: cached interface with key %s was not a notification", id) + } + + return notification, true +} + +func (n *notificationDB) newNotificationQ(i interface{}) *bun.SelectQuery { + return n.conn. + NewSelect(). + Model(i). + Relation("OriginAccount"). + Relation("TargetAccount"). + Relation("Status") +} + +func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) { + if notification, cached := n.notificationCached(id); cached { + return notification, nil + } + + notification := >smodel.Notification{} + + q := n.newNotificationQ(notification). + Where("notification.id = ?", id) + + err := processErrorResponse(q.Scan(ctx)) + + if err == nil && notification != nil { + n.cacheNotification(id, notification) + } + + return notification, err +} + +func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) { + // begin by selecting just the IDs + notifIDs := []*gtsmodel.Notification{} + q := n.conn. + NewSelect(). + Model(¬ifIDs). + Column("id"). + Where("target_account_id = ?", accountID). + Order("id DESC") + + if maxID != "" { + q = q.Where("id < ?", maxID) + } + + if sinceID != "" { + q = q.Where("id > ?", sinceID) + } + + if limit != 0 { + q = q.Limit(limit) + } + + err := processErrorResponse(q.Scan(ctx)) + if err != nil { + return nil, err + } + + // now we have the IDs, select the notifs one by one + // reason for this is that for each notif, we can instead get it from our cache if it's cached + notifications := []*gtsmodel.Notification{} + for _, notifID := range notifIDs { + notif, err := n.GetNotification(ctx, notifID.ID) + errP := processErrorResponse(err) + if errP != nil { + return nil, errP + } + notifications = append(notifications, notif) + } + + return notifications, nil +} diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go new file mode 100644 index 000000000..ccc604baf --- /dev/null +++ b/internal/db/bundb/relationship.go @@ -0,0 +1,328 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + "database/sql" + "fmt" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +type relationshipDB struct { + config *config.Config + conn *bun.DB + log *logrus.Logger +} + +func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *bun.SelectQuery { + return r.conn. + NewSelect(). + Model(block). + Relation("Account"). + Relation("TargetAccount") +} + +func (r *relationshipDB) newFollowQ(follow interface{}) *bun.SelectQuery { + return r.conn. + NewSelect(). + Model(follow). + Relation("Account"). + Relation("TargetAccount") +} + +func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) { + q := r.conn. + NewSelect(). + Model(>smodel.Block{}). + Where("account_id = ?", account1). + Where("target_account_id = ?", account2). + Limit(1) + + if eitherDirection { + q = q. + WhereOr("target_account_id = ?", account1). + Where("account_id = ?", account2) + } + + return exists(ctx, q) +} + +func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) { + block := >smodel.Block{} + + q := r.newBlockQ(block). + Where("block.account_id = ?", account1). + Where("block.target_account_id = ?", account2) + + err := processErrorResponse(q.Scan(ctx)) + + return block, err +} + +func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) { + rel := >smodel.Relationship{ + ID: targetAccount, + } + + // check if the requesting account follows the target account + follow := >smodel.Follow{} + if err := r.conn. + NewSelect(). + Model(follow). + Where("account_id = ?", requestingAccount). + Where("target_account_id = ?", targetAccount). + Limit(1). + Scan(ctx); err != nil { + if err != sql.ErrNoRows { + // a proper error + return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err) + } + // no follow exists so these are all false + rel.Following = false + rel.ShowingReblogs = false + rel.Notifying = false + } else { + // follow exists so we can fill these fields out... + rel.Following = true + rel.ShowingReblogs = follow.ShowReblogs + rel.Notifying = follow.Notify + } + + // check if the target account follows the requesting account + count, err := r.conn. + NewSelect(). + Model(>smodel.Follow{}). + Where("account_id = ?", targetAccount). + Where("target_account_id = ?", requestingAccount). + Limit(1). + Count(ctx) + if err != nil { + return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err) + } + rel.FollowedBy = count > 0 + + // check if the requesting account blocks the target account + count, err = r.conn.NewSelect(). + Model(>smodel.Block{}). + Where("account_id = ?", requestingAccount). + Where("target_account_id = ?", targetAccount). + Limit(1). + Count(ctx) + if err != nil { + return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err) + } + rel.Blocking = count > 0 + + // check if the target account blocks the requesting account + count, err = r.conn. + NewSelect(). + Model(>smodel.Block{}). + Where("account_id = ?", targetAccount). + Where("target_account_id = ?", requestingAccount). + Limit(1). + Count(ctx) + if err != nil { + return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) + } + rel.BlockedBy = count > 0 + + // check if there's a pending following request from requesting account to target account + count, err = r.conn. + NewSelect(). + Model(>smodel.FollowRequest{}). + Where("account_id = ?", requestingAccount). + Where("target_account_id = ?", targetAccount). + Limit(1). + Count(ctx) + if err != nil { + return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) + } + rel.Requested = count > 0 + + return rel, nil +} + +func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) { + if sourceAccount == nil || targetAccount == nil { + return false, nil + } + + q := r.conn. + NewSelect(). + Model(>smodel.Follow{}). + Where("account_id = ?", sourceAccount.ID). + Where("target_account_id = ?", targetAccount.ID). + Limit(1) + + return exists(ctx, q) +} + +func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) { + if sourceAccount == nil || targetAccount == nil { + return false, nil + } + + q := r.conn. + NewSelect(). + Model(>smodel.FollowRequest{}). + Where("account_id = ?", sourceAccount.ID). + Where("target_account_id = ?", targetAccount.ID) + + return exists(ctx, q) +} + +func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) { + if account1 == nil || account2 == nil { + return false, nil + } + + // make sure account 1 follows account 2 + f1, err := r.IsFollowing(ctx, account1, account2) + if err != nil { + return false, processErrorResponse(err) + } + + // make sure account 2 follows account 1 + f2, err := r.IsFollowing(ctx, account2, account1) + if err != nil { + return false, processErrorResponse(err) + } + + return f1 && f2, nil +} + +func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) { + // make sure the original follow request exists + fr := >smodel.FollowRequest{} + if err := r.conn. + NewSelect(). + Model(fr). + Where("account_id = ?", originAccountID). + Where("target_account_id = ?", targetAccountID). + Scan(ctx); err != nil { + return nil, processErrorResponse(err) + } + + // create a new follow to 'replace' the request with + follow := >smodel.Follow{ + ID: fr.ID, + AccountID: originAccountID, + TargetAccountID: targetAccountID, + URI: fr.URI, + } + + // if the follow already exists, just update the URI -- we don't need to do anything else + if _, err := r.conn. + NewInsert(). + Model(follow). + On("CONFLICT CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI). + Exec(ctx); err != nil { + return nil, processErrorResponse(err) + } + + // now remove the follow request + if _, err := r.conn. + NewDelete(). + Model(>smodel.FollowRequest{}). + Where("account_id = ?", originAccountID). + Where("target_account_id = ?", targetAccountID). + Exec(ctx); err != nil { + return nil, processErrorResponse(err) + } + + return follow, nil +} + +func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) { + followRequests := []*gtsmodel.FollowRequest{} + + q := r.newFollowQ(&followRequests). + Where("target_account_id = ?", accountID) + + err := processErrorResponse(q.Scan(ctx)) + + return followRequests, err +} + +func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, db.Error) { + follows := []*gtsmodel.Follow{} + + q := r.newFollowQ(&follows). + Where("account_id = ?", accountID) + + err := processErrorResponse(q.Scan(ctx)) + + return follows, err +} + +func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { + return r.conn. + NewSelect(). + Model(&[]*gtsmodel.Follow{}). + Where("account_id = ?", accountID). + Count(ctx) +} + +func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) { + + follows := []*gtsmodel.Follow{} + + q := r.conn. + NewSelect(). + Model(&follows) + + if localOnly { + // for local accounts let's get where domain is null OR where domain is an empty string, just to be safe + whereGroup := func(q *bun.SelectQuery) *bun.SelectQuery { + q = q. + WhereOr("? IS NULL", bun.Ident("a.domain")). + WhereOr("a.domain = ?", "") + return q + } + + q = q.ColumnExpr("follow.*"). + Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)"). + Where("follow.target_account_id = ?", accountID). + WhereGroup(" AND ", whereGroup) + } else { + q = q.Where("target_account_id = ?", accountID) + } + + if err := q.Scan(ctx); err != nil { + if err == sql.ErrNoRows { + return follows, nil + } + return nil, processErrorResponse(err) + } + return follows, nil +} + +func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { + return r.conn. + NewSelect(). + Model(&[]*gtsmodel.Follow{}). + Where("target_account_id = ?", accountID). + Count(ctx) +} diff --git a/internal/db/bundb/session.go b/internal/db/bundb/session.go new file mode 100644 index 000000000..87e20673d --- /dev/null +++ b/internal/db/bundb/session.go @@ -0,0 +1,85 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + "crypto/rand" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/uptrace/bun" +) + +type sessionDB struct { + config *config.Config + conn *bun.DB + log *logrus.Logger +} + +func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) { + rs := new(gtsmodel.RouterSession) + + q := s.conn. + NewSelect(). + Model(rs). + Limit(1) + + _, err := q.Exec(ctx) + + err = processErrorResponse(err) + + return rs, err +} + +func (s *sessionDB) CreateSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) { + auth := make([]byte, 32) + crypt := make([]byte, 32) + + if _, err := rand.Read(auth); err != nil { + return nil, err + } + if _, err := rand.Read(crypt); err != nil { + return nil, err + } + + rid, err := id.NewULID() + if err != nil { + return nil, err + } + + rs := >smodel.RouterSession{ + ID: rid, + Auth: auth, + Crypt: crypt, + } + + q := s.conn. + NewInsert(). + Model(rs) + + _, err = q.Exec(ctx) + + err = processErrorResponse(err) + + return rs, err +} diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go new file mode 100644 index 000000000..da8d8ca41 --- /dev/null +++ b/internal/db/bundb/status.go @@ -0,0 +1,375 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "container/list" + "context" + "errors" + "time" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/cache" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +type statusDB struct { + config *config.Config + conn *bun.DB + log *logrus.Logger + cache cache.Cache +} + +func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) { + if s.cache == nil { + s.cache = cache.New() + } + + if err := s.cache.Store(id, status); err != nil { + s.log.Panicf("statusDB: error storing in cache: %s", err) + } +} + +func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) { + if s.cache == nil { + s.cache = cache.New() + return nil, false + } + + sI, err := s.cache.Fetch(id) + if err != nil || sI == nil { + return nil, false + } + + status, ok := sI.(*gtsmodel.Status) + if !ok { + s.log.Panicf("statusDB: cached interface with key %s was not a status", id) + } + + return status, true +} + +func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { + return s.conn. + NewSelect(). + Model(status). + Relation("Attachments"). + Relation("Tags"). + Relation("Mentions"). + Relation("Emojis"). + Relation("Account"). + Relation("InReplyToAccount"). + Relation("BoostOfAccount"). + Relation("CreatedWithApplication") +} + +func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status { + if status.InReplyToID != "" && status.InReplyTo == nil { + if inReplyTo, cached := s.statusCached(status.InReplyToID); cached { + status.InReplyTo = inReplyTo + } else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil { + status.InReplyTo = inReplyTo + } + } + + if status.BoostOfID != "" && status.BoostOf == nil { + if boostOf, cached := s.statusCached(status.BoostOfID); cached { + status.BoostOf = boostOf + } else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil { + status.BoostOf = boostOf + } + } + + return status +} + +func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery { + return s.conn. + NewSelect(). + Model(faves). + Relation("Account"). + Relation("TargetAccount"). + Relation("Status") +} + +func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) { + if status, cached := s.statusCached(id); cached { + return status, nil + } + + status := new(gtsmodel.Status) + + q := s.newStatusQ(status). + Where("status.id = ?", id) + + err := processErrorResponse(q.Scan(ctx)) + + if err != nil { + return nil, err + } + + if status != nil { + s.cacheStatus(id, status) + } + + return s.getAttachedStatuses(ctx, status), err +} + +func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) { + if status, cached := s.statusCached(uri); cached { + return status, nil + } + + status := >smodel.Status{} + + q := s.newStatusQ(status). + Where("LOWER(status.uri) = LOWER(?)", uri) + + err := processErrorResponse(q.Scan(ctx)) + + if err != nil { + return nil, err + } + + if status != nil { + s.cacheStatus(uri, status) + } + + return s.getAttachedStatuses(ctx, status), err +} + +func (s *statusDB) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) { + if status, cached := s.statusCached(uri); cached { + return status, nil + } + + status := >smodel.Status{} + + q := s.newStatusQ(status). + Where("LOWER(status.url) = LOWER(?)", uri) + + err := processErrorResponse(q.Scan(ctx)) + + if err != nil { + return nil, err + } + + if status != nil { + s.cacheStatus(uri, status) + } + + return s.getAttachedStatuses(ctx, status), err +} + +func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error { + transaction := func(ctx context.Context, tx bun.Tx) error { + // create links between this status and any emojis it uses + for _, i := range status.EmojiIDs { + if _, err := tx.NewInsert().Model(>smodel.StatusToEmoji{ + StatusID: status.ID, + EmojiID: i, + }).Exec(ctx); err != nil { + return err + } + } + + // create links between this status and any tags it uses + for _, i := range status.TagIDs { + if _, err := tx.NewInsert().Model(>smodel.StatusToTag{ + StatusID: status.ID, + TagID: i, + }).Exec(ctx); err != nil { + return err + } + } + + // change the status ID of the media attachments to the new status + for _, a := range status.Attachments { + a.StatusID = status.ID + a.UpdatedAt = time.Now() + if _, err := s.conn.NewUpdate().Model(a). + Where("id = ?", a.ID). + Exec(ctx); err != nil { + return err + } + } + + _, err := tx.NewInsert().Model(status).Exec(ctx) + return err + } + + return processErrorResponse(s.conn.RunInTx(ctx, nil, transaction)) +} + +func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { + parents := []*gtsmodel.Status{} + s.statusParent(ctx, status, &parents, onlyDirect) + + return parents, nil +} + +func (s *statusDB) statusParent(ctx context.Context, status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) { + if status.InReplyToID == "" { + return + } + + parentStatus, err := s.GetStatusByID(ctx, status.InReplyToID) + if err == nil { + *foundStatuses = append(*foundStatuses, parentStatus) + } + + if onlyDirect { + return + } + + s.statusParent(ctx, parentStatus, foundStatuses, false) +} + +func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) { + foundStatuses := &list.List{} + foundStatuses.PushFront(status) + s.statusChildren(ctx, status, foundStatuses, onlyDirect, minID) + + children := []*gtsmodel.Status{} + for e := foundStatuses.Front(); e != nil; e = e.Next() { + entry, ok := e.Value.(*gtsmodel.Status) + if !ok { + panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) + } + + // only append children, not the overall parent status + if entry.ID != status.ID { + children = append(children, entry) + } + } + + return children, nil +} + +func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) { + immediateChildren := []*gtsmodel.Status{} + + q := s.conn. + NewSelect(). + Model(&immediateChildren). + Where("in_reply_to_id = ?", status.ID) + if minID != "" { + q = q.Where("status.id > ?", minID) + } + + if err := q.Scan(ctx); err != nil { + return + } + + for _, child := range immediateChildren { + insertLoop: + for e := foundStatuses.Front(); e != nil; e = e.Next() { + entry, ok := e.Value.(*gtsmodel.Status) + if !ok { + panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) + } + + if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID { + foundStatuses.InsertAfter(child, e) + break insertLoop + } + } + + // only do one loop if we only want direct children + if onlyDirect { + return + } + s.statusChildren(ctx, child, foundStatuses, false, minID) + } +} + +func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { + return s.conn.NewSelect().Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count(ctx) +} + +func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { + return s.conn.NewSelect().Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count(ctx) +} + +func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { + return s.conn.NewSelect().Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count(ctx) +} + +func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { + q := s.conn. + NewSelect(). + Model(>smodel.StatusFave{}). + Where("status_id = ?", status.ID). + Where("account_id = ?", accountID) + + return exists(ctx, q) +} + +func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { + q := s.conn. + NewSelect(). + Model(>smodel.Status{}). + Where("boost_of_id = ?", status.ID). + Where("account_id = ?", accountID) + + return exists(ctx, q) +} + +func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { + q := s.conn. + NewSelect(). + Model(>smodel.StatusMute{}). + Where("status_id = ?", status.ID). + Where("account_id = ?", accountID) + + return exists(ctx, q) +} + +func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { + q := s.conn. + NewSelect(). + Model(>smodel.StatusBookmark{}). + Where("status_id = ?", status.ID). + Where("account_id = ?", accountID) + + return exists(ctx, q) +} + +func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) { + faves := []*gtsmodel.StatusFave{} + + q := s.newFaveQ(&faves). + Where("status_id = ?", status.ID) + + err := processErrorResponse(q.Scan(ctx)) + return faves, err +} + +func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) { + reblogs := []*gtsmodel.Status{} + + q := s.newStatusQ(&reblogs). + Where("boost_of_id = ?", status.ID) + + err := processErrorResponse(q.Scan(ctx)) + return reblogs, err +} diff --git a/internal/db/bundb/status_test.go b/internal/db/bundb/status_test.go new file mode 100644 index 000000000..513000577 --- /dev/null +++ b/internal/db/bundb/status_test.go @@ -0,0 +1,136 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type StatusTestSuite struct { + BunDBStandardTestSuite +} + +func (suite *StatusTestSuite) SetupSuite() { + suite.testTokens = testrig.NewTestTokens() + suite.testClients = testrig.NewTestClients() + suite.testApplications = testrig.NewTestApplications() + suite.testUsers = testrig.NewTestUsers() + suite.testAccounts = testrig.NewTestAccounts() + suite.testAttachments = testrig.NewTestAttachments() + suite.testStatuses = testrig.NewTestStatuses() + suite.testTags = testrig.NewTestTags() + suite.testMentions = testrig.NewTestMentions() +} + +func (suite *StatusTestSuite) SetupTest() { + suite.config = testrig.NewTestConfig() + suite.db = testrig.NewTestDB() + suite.log = testrig.NewTestLog() + + testrig.StandardDBSetup(suite.db, suite.testAccounts) +} + +func (suite *StatusTestSuite) TearDownTest() { + testrig.StandardDBTeardown(suite.db) +} + +func (suite *StatusTestSuite) TestGetStatusByID() { + status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_1_status_1"].ID) + if err != nil { + fmt.Println(err.Error()) + suite.FailNow(err.Error()) + } + suite.NotNil(status) + suite.NotNil(status.Account) + suite.NotNil(status.CreatedWithApplication) + suite.Nil(status.BoostOf) + suite.Nil(status.BoostOfAccount) + suite.Nil(status.InReplyTo) + suite.Nil(status.InReplyToAccount) +} + +func (suite *StatusTestSuite) TestGetStatusByURI() { + status, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI) + if err != nil { + suite.FailNow(err.Error()) + } + suite.NotNil(status) + suite.NotNil(status.Account) + suite.NotNil(status.CreatedWithApplication) + suite.Nil(status.BoostOf) + suite.Nil(status.BoostOfAccount) + suite.Nil(status.InReplyTo) + suite.Nil(status.InReplyToAccount) +} + +func (suite *StatusTestSuite) TestGetStatusWithExtras() { + status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["admin_account_status_1"].ID) + if err != nil { + suite.FailNow(err.Error()) + } + suite.NotNil(status) + suite.NotNil(status.Account) + suite.NotNil(status.CreatedWithApplication) + suite.NotEmpty(status.Tags) + suite.NotEmpty(status.Attachments) + suite.NotEmpty(status.Emojis) +} + +func (suite *StatusTestSuite) TestGetStatusWithMention() { + status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_2_status_5"].ID) + if err != nil { + suite.FailNow(err.Error()) + } + suite.NotNil(status) + suite.NotNil(status.Account) + suite.NotNil(status.CreatedWithApplication) + suite.NotEmpty(status.Mentions) + suite.NotEmpty(status.MentionIDs) + suite.NotNil(status.InReplyTo) + suite.NotNil(status.InReplyToAccount) +} + +func (suite *StatusTestSuite) TestGetStatusTwice() { + before1 := time.Now() + _, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI) + suite.NoError(err) + after1 := time.Now() + duration1 := after1.Sub(before1) + fmt.Println(duration1.Milliseconds()) + + before2 := time.Now() + _, err = suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI) + suite.NoError(err) + after2 := time.Now() + duration2 := after2.Sub(before2) + fmt.Println(duration2.Milliseconds()) + + // second retrieval should be several orders faster since it will be cached now + suite.Less(duration2, duration1) +} + +func TestStatusTestSuite(t *testing.T) { + suite.Run(t, new(StatusTestSuite)) +} diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go new file mode 100644 index 000000000..b62ad4c50 --- /dev/null +++ b/internal/db/bundb/timeline.go @@ -0,0 +1,199 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + "database/sql" + "sort" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +type timelineDB struct { + config *config.Config + conn *bun.DB + log *logrus.Logger +} + +func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { + statuses := []*gtsmodel.Status{} + q := t.conn. + NewSelect(). + Model(&statuses) + + q = q.ColumnExpr("status.*"). + // Find out who accountID follows. + Join("LEFT JOIN follows AS f ON f.target_account_id = status.account_id"). + // Sort by highest ID (newest) to lowest ID (oldest) + Order("status.id DESC") + + if maxID != "" { + // return only statuses LOWER (ie., older) than maxID + q = q.Where("status.id < ?", maxID) + } + + if sinceID != "" { + // return only statuses HIGHER (ie., newer) than sinceID + q = q.Where("status.id > ?", sinceID) + } + + if minID != "" { + // return only statuses HIGHER (ie., newer) than minID + q = q.Where("status.id > ?", minID) + } + + if local { + // return only statuses posted by local account havers + q = q.Where("status.local = ?", local) + } + + if limit > 0 { + // limit amount of statuses returned + q = q.Limit(limit) + } + + // Use a WhereGroup here to specify that we want EITHER statuses posted by accounts that accountID follows, + // OR statuses posted by accountID itself (since a user should be able to see their own statuses). + // + // This is equivalent to something like WHERE ... AND (... OR ...) + // See: https://bun.uptrace.dev/guide/queries.html#select + whereGroup := func(*bun.SelectQuery) *bun.SelectQuery { + return q. + WhereOr("f.account_id = ?", accountID). + WhereOr("status.account_id = ?", accountID) + } + + q = q.WhereGroup(" AND ", whereGroup) + + return statuses, processErrorResponse(q.Scan(ctx)) +} + +func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { + statuses := []*gtsmodel.Status{} + + q := t.conn. + NewSelect(). + Model(&statuses). + Where("visibility = ?", gtsmodel.VisibilityPublic). + Where("? IS NULL", bun.Ident("in_reply_to_id")). + Where("? IS NULL", bun.Ident("in_reply_to_uri")). + Where("? IS NULL", bun.Ident("boost_of_id")). + Order("status.id DESC") + + if maxID != "" { + q = q.Where("status.id < ?", maxID) + } + + if sinceID != "" { + q = q.Where("status.id > ?", sinceID) + } + + if minID != "" { + q = q.Where("status.id > ?", minID) + } + + if local { + q = q.Where("status.local = ?", local) + } + + if limit > 0 { + q = q.Limit(limit) + } + + return statuses, processErrorResponse(q.Scan(ctx)) +} + +// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20! +// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds. +func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) { + + faves := []*gtsmodel.StatusFave{} + + fq := t.conn. + NewSelect(). + Model(&faves). + Where("account_id = ?", accountID). + Order("id DESC") + + if maxID != "" { + fq = fq.Where("id < ?", maxID) + } + + if minID != "" { + fq = fq.Where("id > ?", minID) + } + + if limit > 0 { + fq = fq.Limit(limit) + } + + err := fq.Scan(ctx) + if err != nil { + if err == sql.ErrNoRows { + return nil, "", "", db.ErrNoEntries + } + return nil, "", "", err + } + + if len(faves) == 0 { + return nil, "", "", db.ErrNoEntries + } + + // map[statusID]faveID -- we need this to sort statuses by fave ID rather than their own ID + statusesFavesMap := map[string]string{} + + in := []string{} + for _, f := range faves { + statusesFavesMap[f.StatusID] = f.ID + in = append(in, f.StatusID) + } + + statuses := []*gtsmodel.Status{} + err = t.conn. + NewSelect(). + Model(&statuses). + Where("id IN (?)", bun.In(in)). + Scan(ctx) + if err != nil { + if err == sql.ErrNoRows { + return nil, "", "", db.ErrNoEntries + } + return nil, "", "", err + } + + if len(statuses) == 0 { + return nil, "", "", db.ErrNoEntries + } + + // arrange statuses by fave ID + sort.Slice(statuses, func(i int, j int) bool { + statusI := statuses[i] + statusJ := statuses[j] + return statusesFavesMap[statusI.ID] < statusesFavesMap[statusJ.ID] + }) + + nextMaxID := faves[len(faves)-1].ID + prevMinID := faves[0].ID + return statuses, nextMaxID, prevMinID, nil +} diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go new file mode 100644 index 000000000..115d18de2 --- /dev/null +++ b/internal/db/bundb/util.go @@ -0,0 +1,78 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + "strings" + + "database/sql" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/uptrace/bun" +) + +// processErrorResponse parses the given error and returns an appropriate DBError. +func processErrorResponse(err error) db.Error { + switch err { + case nil: + return nil + case sql.ErrNoRows: + return db.ErrNoEntries + default: + if strings.Contains(err.Error(), "duplicate key value violates unique constraint") { + return db.ErrAlreadyExists + } + return err + } +} + +func exists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) { + count, err := q.Count(ctx) + + exists := count != 0 + + err = processErrorResponse(err) + + if err != nil { + if err == db.ErrNoEntries { + return false, nil + } + return false, err + } + + return exists, nil +} + +func notExists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) { + count, err := q.Count(ctx) + + notExists := count == 0 + + err = processErrorResponse(err) + + if err != nil { + if err == db.ErrNoEntries { + return true, nil + } + return false, err + } + + return notExists, nil +} |