diff options
Diffstat (limited to 'internal/db/pg')
| -rw-r--r-- | internal/db/pg/account.go | 256 | ||||
| -rw-r--r-- | internal/db/pg/account_test.go | 70 | ||||
| -rw-r--r-- | internal/db/pg/admin.go | 235 | ||||
| -rw-r--r-- | internal/db/pg/basic.go | 205 | ||||
| -rw-r--r-- | internal/db/pg/domain.go | 83 | ||||
| -rw-r--r-- | internal/db/pg/instance.go | 112 | ||||
| -rw-r--r-- | internal/db/pg/media.go | 53 | ||||
| -rw-r--r-- | internal/db/pg/mention.go | 108 | ||||
| -rw-r--r-- | internal/db/pg/notification.go | 135 | ||||
| -rw-r--r-- | internal/db/pg/pg.go | 420 | ||||
| -rw-r--r-- | internal/db/pg/pg_test.go | 47 | ||||
| -rw-r--r-- | internal/db/pg/relationship.go | 276 | ||||
| -rw-r--r-- | internal/db/pg/status.go | 318 | ||||
| -rw-r--r-- | internal/db/pg/status_test.go | 134 | ||||
| -rw-r--r-- | internal/db/pg/timeline.go | 210 | ||||
| -rw-r--r-- | internal/db/pg/util.go | 25 |
16 files changed, 0 insertions, 2687 deletions
diff --git a/internal/db/pg/account.go b/internal/db/pg/account.go deleted file mode 100644 index 3889c6601..000000000 --- a/internal/db/pg/account.go +++ /dev/null @@ -1,256 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg - -import ( - "context" - "errors" - "fmt" - "time" - - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" - "github.com/sirupsen/logrus" - "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -type accountDB struct { - config *config.Config - conn *pg.DB - log *logrus.Logger - cancel context.CancelFunc -} - -func (a *accountDB) newAccountQ(account *gtsmodel.Account) *orm.Query { - return a.conn.Model(account). - Relation("AvatarMediaAttachment"). - Relation("HeaderMediaAttachment") -} - -func (a *accountDB) GetAccountByID(id string) (*gtsmodel.Account, db.Error) { - account := >smodel.Account{} - - q := a.newAccountQ(account). - Where("account.id = ?", id) - - err := processErrorResponse(q.Select()) - - return account, err -} - -func (a *accountDB) GetAccountByURI(uri string) (*gtsmodel.Account, db.Error) { - account := >smodel.Account{} - - q := a.newAccountQ(account). - Where("account.uri = ?", uri) - - err := processErrorResponse(q.Select()) - - return account, err -} - -func (a *accountDB) GetAccountByURL(uri string) (*gtsmodel.Account, db.Error) { - account := >smodel.Account{} - - q := a.newAccountQ(account). - Where("account.url = ?", uri) - - err := processErrorResponse(q.Select()) - - return account, err -} - -func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.Error) { - account := >smodel.Account{} - - q := a.newAccountQ(account) - - if domain == "" { - q = q. - Where("account.username = ?", domain). - Where("account.domain = ?", domain) - } else { - q = q. - Where("account.username = ?", domain). - Where("? IS NULL", pg.Ident("domain")) - } - - err := processErrorResponse(q.Select()) - - return account, err -} - -func (a *accountDB) GetAccountLastPosted(accountID string) (time.Time, db.Error) { - status := >smodel.Status{} - - q := a.conn.Model(status). - Order("id DESC"). - Limit(1). - Where("account_id = ?", accountID). - Column("created_at") - - err := processErrorResponse(q.Select()) - - return status.CreatedAt, err -} - -func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error { - if mediaAttachment.Avatar && mediaAttachment.Header { - return errors.New("one media attachment cannot be both header and avatar") - } - - var headerOrAVI string - if mediaAttachment.Avatar { - headerOrAVI = "avatar" - } else if mediaAttachment.Header { - headerOrAVI = "header" - } else { - return errors.New("given media attachment was neither a header nor an avatar") - } - - // TODO: there are probably more side effects here that need to be handled - if _, err := a.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil { - return err - } - - if _, err := a.conn.Model(>smodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil { - return err - } - return nil -} - -func (a *accountDB) GetLocalAccountByUsername(username string) (*gtsmodel.Account, db.Error) { - account := >smodel.Account{} - - q := a.newAccountQ(account). - Where("username = ?", username). - Where("? IS NULL", pg.Ident("domain")) - - err := processErrorResponse(q.Select()) - - return account, err -} - -func (a *accountDB) GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, db.Error) { - faves := []*gtsmodel.StatusFave{} - - if err := a.conn.Model(&faves). - Where("account_id = ?", accountID). - Select(); err != nil { - if err == pg.ErrNoRows { - return faves, nil - } - return nil, err - } - return faves, nil -} - -func (a *accountDB) CountAccountStatuses(accountID string) (int, db.Error) { - return a.conn.Model(>smodel.Status{}).Where("account_id = ?", accountID).Count() -} - -func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) { - a.log.Debugf("getting statuses for account %s", accountID) - statuses := []*gtsmodel.Status{} - - q := a.conn.Model(&statuses).Order("id DESC") - if accountID != "" { - q = q.Where("account_id = ?", accountID) - } - - if limit != 0 { - q = q.Limit(limit) - } - - if excludeReplies { - q = q.Where("? IS NULL", pg.Ident("in_reply_to_id")) - } - - if pinnedOnly { - q = q.Where("pinned = ?", true) - } - - if mediaOnly { - q = q.WhereGroup(func(q *pg.Query) (*pg.Query, error) { - return q.Where("? IS NOT NULL", pg.Ident("attachments")).Where("attachments != '{}'"), nil - }) - } - - if maxID != "" { - q = q.Where("id < ?", maxID) - } - - if err := q.Select(); err != nil { - if err == pg.ErrNoRows { - return nil, db.ErrNoEntries - } - return nil, err - } - - if len(statuses) == 0 { - return nil, db.ErrNoEntries - } - - a.log.Debugf("returning statuses for account %s", accountID) - return statuses, nil -} - -func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) { - blocks := []*gtsmodel.Block{} - - fq := a.conn.Model(&blocks). - Where("block.account_id = ?", accountID). - Relation("TargetAccount"). - Order("block.id DESC") - - if maxID != "" { - fq = fq.Where("block.id < ?", maxID) - } - - if sinceID != "" { - fq = fq.Where("block.id > ?", sinceID) - } - - if limit > 0 { - fq = fq.Limit(limit) - } - - err := fq.Select() - if err != nil { - if err == pg.ErrNoRows { - return nil, "", "", db.ErrNoEntries - } - return nil, "", "", err - } - - if len(blocks) == 0 { - return nil, "", "", db.ErrNoEntries - } - - accounts := []*gtsmodel.Account{} - for _, b := range blocks { - accounts = append(accounts, b.TargetAccount) - } - - nextMaxID := blocks[len(blocks)-1].ID - prevMinID := blocks[0].ID - return accounts, nextMaxID, prevMinID, nil -} diff --git a/internal/db/pg/account_test.go b/internal/db/pg/account_test.go deleted file mode 100644 index 7ea5ff39a..000000000 --- a/internal/db/pg/account_test.go +++ /dev/null @@ -1,70 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg_test - -import ( - "testing" - - "github.com/stretchr/testify/suite" - "github.com/superseriousbusiness/gotosocial/testrig" -) - -type AccountTestSuite struct { - PGStandardTestSuite -} - -func (suite *AccountTestSuite) SetupSuite() { - suite.testTokens = testrig.NewTestTokens() - suite.testClients = testrig.NewTestClients() - suite.testApplications = testrig.NewTestApplications() - suite.testUsers = testrig.NewTestUsers() - suite.testAccounts = testrig.NewTestAccounts() - suite.testAttachments = testrig.NewTestAttachments() - suite.testStatuses = testrig.NewTestStatuses() - suite.testTags = testrig.NewTestTags() - suite.testMentions = testrig.NewTestMentions() -} - -func (suite *AccountTestSuite) SetupTest() { - suite.config = testrig.NewTestConfig() - suite.db = testrig.NewTestDB() - suite.log = testrig.NewTestLog() - - testrig.StandardDBSetup(suite.db, suite.testAccounts) -} - -func (suite *AccountTestSuite) TearDownTest() { - testrig.StandardDBTeardown(suite.db) -} - -func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() { - account, err := suite.db.GetAccountByID(suite.testAccounts["local_account_1"].ID) - if err != nil { - suite.FailNow(err.Error()) - } - suite.NotNil(account) - suite.NotNil(account.AvatarMediaAttachment) - suite.NotEmpty(account.AvatarMediaAttachment.URL) - suite.NotNil(account.HeaderMediaAttachment) - suite.NotEmpty(account.HeaderMediaAttachment.URL) -} - -func TestAccountTestSuite(t *testing.T) { - suite.Run(t, new(AccountTestSuite)) -} diff --git a/internal/db/pg/admin.go b/internal/db/pg/admin.go deleted file mode 100644 index 854f56ef0..000000000 --- a/internal/db/pg/admin.go +++ /dev/null @@ -1,235 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg - -import ( - "context" - "crypto/rand" - "crypto/rsa" - "fmt" - "net" - "net/mail" - "strings" - "time" - - "github.com/go-pg/pg/v10" - "github.com/sirupsen/logrus" - "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/id" - "github.com/superseriousbusiness/gotosocial/internal/util" - "golang.org/x/crypto/bcrypt" -) - -type adminDB struct { - config *config.Config - conn *pg.DB - log *logrus.Logger - cancel context.CancelFunc -} - -func (a *adminDB) IsUsernameAvailable(username string) db.Error { - // if no error we fail because it means we found something - // if error but it's not pg.ErrNoRows then we fail - // if err is pg.ErrNoRows we're good, we found nothing so continue - if err := a.conn.Model(>smodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil { - return fmt.Errorf("username %s already in use", username) - } else if err != pg.ErrNoRows { - return fmt.Errorf("db error: %s", err) - } - return nil -} - -func (a *adminDB) IsEmailAvailable(email string) db.Error { - // parse the domain from the email - m, err := mail.ParseAddress(email) - if err != nil { - return fmt.Errorf("error parsing email address %s: %s", email, err) - } - domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @ - - // check if the email domain is blocked - if err := a.conn.Model(>smodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil { - // fail because we found something - return fmt.Errorf("email domain %s is blocked", domain) - } else if err != pg.ErrNoRows { - // fail because we got an unexpected error - return fmt.Errorf("db error: %s", err) - } - - // check if this email is associated with a user already - if err := a.conn.Model(>smodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil { - // fail because we found something - return fmt.Errorf("email %s already in use", email) - } else if err != pg.ErrNoRows { - // fail because we got an unexpected error - return fmt.Errorf("db error: %s", err) - } - return nil -} - -func (a *adminDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) { - key, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - a.log.Errorf("error creating new rsa key: %s", err) - return nil, err - } - - // if something went wrong while creating a user, we might already have an account, so check here first... - acct := >smodel.Account{} - err = a.conn.Model(acct).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select() - if err != nil { - // there's been an actual error - if err != pg.ErrNoRows { - return nil, fmt.Errorf("db error checking existence of account: %s", err) - } - - // we just don't have an account yet create one - newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host) - newAccountID, err := id.NewRandomULID() - if err != nil { - return nil, err - } - - acct = >smodel.Account{ - ID: newAccountID, - Username: username, - DisplayName: username, - Reason: reason, - URL: newAccountURIs.UserURL, - PrivateKey: key, - PublicKey: &key.PublicKey, - PublicKeyURI: newAccountURIs.PublicKeyURI, - ActorType: gtsmodel.ActivityStreamsPerson, - URI: newAccountURIs.UserURI, - InboxURI: newAccountURIs.InboxURI, - OutboxURI: newAccountURIs.OutboxURI, - FollowersURI: newAccountURIs.FollowersURI, - FollowingURI: newAccountURIs.FollowingURI, - FeaturedCollectionURI: newAccountURIs.CollectionURI, - } - if _, err = a.conn.Model(acct).Insert(); err != nil { - return nil, err - } - } - - pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - if err != nil { - return nil, fmt.Errorf("error hashing password: %s", err) - } - - newUserID, err := id.NewRandomULID() - if err != nil { - return nil, err - } - - u := >smodel.User{ - ID: newUserID, - AccountID: acct.ID, - EncryptedPassword: string(pw), - SignUpIP: signUpIP.To4(), - Locale: locale, - UnconfirmedEmail: email, - CreatedByApplicationID: appID, - Approved: !requireApproval, // if we don't require moderator approval, just pre-approve the user - } - - if emailVerified { - u.ConfirmedAt = time.Now() - u.Email = email - } - - if admin { - u.Admin = true - u.Moderator = true - } - - if _, err = a.conn.Model(u).Insert(); err != nil { - return nil, err - } - - return u, nil -} - -func (a *adminDB) CreateInstanceAccount() db.Error { - username := a.config.Host - key, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - a.log.Errorf("error creating new rsa key: %s", err) - return err - } - - aID, err := id.NewRandomULID() - if err != nil { - return err - } - - newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host) - acct := >smodel.Account{ - ID: aID, - Username: a.config.Host, - DisplayName: username, - URL: newAccountURIs.UserURL, - PrivateKey: key, - PublicKey: &key.PublicKey, - PublicKeyURI: newAccountURIs.PublicKeyURI, - ActorType: gtsmodel.ActivityStreamsPerson, - URI: newAccountURIs.UserURI, - InboxURI: newAccountURIs.InboxURI, - OutboxURI: newAccountURIs.OutboxURI, - FollowersURI: newAccountURIs.FollowersURI, - FollowingURI: newAccountURIs.FollowingURI, - FeaturedCollectionURI: newAccountURIs.CollectionURI, - } - inserted, err := a.conn.Model(acct).Where("username = ?", username).SelectOrInsert() - if err != nil { - return err - } - if inserted { - a.log.Infof("created instance account %s with id %s", username, acct.ID) - } else { - a.log.Infof("instance account %s already exists with id %s", username, acct.ID) - } - return nil -} - -func (a *adminDB) CreateInstanceInstance() db.Error { - iID, err := id.NewRandomULID() - if err != nil { - return err - } - - i := >smodel.Instance{ - ID: iID, - Domain: a.config.Host, - Title: a.config.Host, - URI: fmt.Sprintf("%s://%s", a.config.Protocol, a.config.Host), - } - inserted, err := a.conn.Model(i).Where("domain = ?", a.config.Host).SelectOrInsert() - if err != nil { - return err - } - if inserted { - a.log.Infof("created instance instance %s with id %s", a.config.Host, i.ID) - } else { - a.log.Infof("instance instance %s already exists with id %s", a.config.Host, i.ID) - } - return nil -} diff --git a/internal/db/pg/basic.go b/internal/db/pg/basic.go deleted file mode 100644 index 6e76b4450..000000000 --- a/internal/db/pg/basic.go +++ /dev/null @@ -1,205 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg - -import ( - "context" - "errors" - "fmt" - "strings" - - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" - "github.com/sirupsen/logrus" - "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" -) - -type basicDB struct { - config *config.Config - conn *pg.DB - log *logrus.Logger - cancel context.CancelFunc -} - -func (b *basicDB) Put(i interface{}) db.Error { - _, err := b.conn.Model(i).Insert(i) - if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") { - return db.ErrAlreadyExists - } - return err -} - -func (b *basicDB) GetByID(id string, i interface{}) db.Error { - if err := b.conn.Model(i).Where("id = ?", id).Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries - } - return err - - } - return nil -} - -func (b *basicDB) GetWhere(where []db.Where, i interface{}) db.Error { - if len(where) == 0 { - return errors.New("no queries provided") - } - - q := b.conn.Model(i) - for _, w := range where { - - if w.Value == nil { - q = q.Where("? IS NULL", pg.Ident(w.Key)) - } else { - if w.CaseInsensitive { - q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value) - } else { - q = q.Where("? = ?", pg.Safe(w.Key), w.Value) - } - } - } - - if err := q.Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries - } - return err - } - return nil -} - -func (b *basicDB) GetAll(i interface{}) db.Error { - if err := b.conn.Model(i).Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries - } - return err - } - return nil -} - -func (b *basicDB) DeleteByID(id string, i interface{}) db.Error { - if _, err := b.conn.Model(i).Where("id = ?", id).Delete(); err != nil { - // if there are no rows *anyway* then that's fine - // just return err if there's an actual error - if err != pg.ErrNoRows { - return err - } - } - return nil -} - -func (b *basicDB) DeleteWhere(where []db.Where, i interface{}) db.Error { - if len(where) == 0 { - return errors.New("no queries provided") - } - - q := b.conn.Model(i) - for _, w := range where { - q = q.Where("? = ?", pg.Safe(w.Key), w.Value) - } - - if _, err := q.Delete(); err != nil { - // if there are no rows *anyway* then that's fine - // just return err if there's an actual error - if err != pg.ErrNoRows { - return err - } - } - return nil -} - -func (b *basicDB) Upsert(i interface{}, conflictColumn string) db.Error { - if _, err := b.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries - } - return err - } - return nil -} - -func (b *basicDB) UpdateByID(id string, i interface{}) db.Error { - if _, err := b.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries - } - return err - } - return nil -} - -func (b *basicDB) UpdateOneByID(id string, key string, value interface{}, i interface{}) db.Error { - _, err := b.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update() - return err -} - -func (b *basicDB) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) db.Error { - q := b.conn.Model(i) - - for _, w := range where { - if w.Value == nil { - q = q.Where("? IS NULL", pg.Ident(w.Key)) - } else { - if w.CaseInsensitive { - q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value) - } else { - q = q.Where("? = ?", pg.Safe(w.Key), w.Value) - } - } - } - - q = q.Set("? = ?", pg.Safe(key), value) - - _, err := q.Update() - - return err -} - -func (b *basicDB) CreateTable(i interface{}) db.Error { - return b.conn.Model(i).CreateTable(&orm.CreateTableOptions{ - IfNotExists: true, - }) -} - -func (b *basicDB) DropTable(i interface{}) db.Error { - return b.conn.Model(i).DropTable(&orm.DropTableOptions{ - IfExists: true, - }) -} - -func (b *basicDB) RegisterTable(i interface{}) db.Error { - orm.RegisterTable(i) - return nil -} - -func (b *basicDB) IsHealthy(ctx context.Context) db.Error { - return b.conn.Ping(ctx) -} - -func (b *basicDB) Stop(ctx context.Context) db.Error { - b.log.Info("closing db connection") - if err := b.conn.Close(); err != nil { - // only cancel if there's a problem closing the db - b.cancel() - return err - } - return nil -} diff --git a/internal/db/pg/domain.go b/internal/db/pg/domain.go deleted file mode 100644 index 4e9b2ab48..000000000 --- a/internal/db/pg/domain.go +++ /dev/null @@ -1,83 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg - -import ( - "context" - "net/url" - - "github.com/go-pg/pg/v10" - "github.com/sirupsen/logrus" - "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/util" -) - -type domainDB struct { - config *config.Config - conn *pg.DB - log *logrus.Logger - cancel context.CancelFunc -} - -func (d *domainDB) IsDomainBlocked(domain string) (bool, db.Error) { - if domain == "" { - return false, nil - } - - blocked, err := d.conn. - Model(>smodel.DomainBlock{}). - Where("LOWER(domain) = LOWER(?)", domain). - Exists() - - err = processErrorResponse(err) - - return blocked, err -} - -func (d *domainDB) AreDomainsBlocked(domains []string) (bool, db.Error) { - // filter out any doubles - uniqueDomains := util.UniqueStrings(domains) - - for _, domain := range uniqueDomains { - if blocked, err := d.IsDomainBlocked(domain); err != nil { - return false, err - } else if blocked { - return blocked, nil - } - } - - // no blocks found - return false, nil -} - -func (d *domainDB) IsURIBlocked(uri *url.URL) (bool, db.Error) { - domain := uri.Hostname() - return d.IsDomainBlocked(domain) -} - -func (d *domainDB) AreURIsBlocked(uris []*url.URL) (bool, db.Error) { - domains := []string{} - for _, uri := range uris { - domains = append(domains, uri.Hostname()) - } - - return d.AreDomainsBlocked(domains) -} diff --git a/internal/db/pg/instance.go b/internal/db/pg/instance.go deleted file mode 100644 index 968832ca5..000000000 --- a/internal/db/pg/instance.go +++ /dev/null @@ -1,112 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg - -import ( - "context" - - "github.com/go-pg/pg/v10" - "github.com/sirupsen/logrus" - "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -type instanceDB struct { - config *config.Config - conn *pg.DB - log *logrus.Logger - cancel context.CancelFunc -} - -func (i *instanceDB) CountInstanceUsers(domain string) (int, db.Error) { - q := i.conn.Model(&[]*gtsmodel.Account{}) - - if domain == i.config.Host { - // if the domain is *this* domain, just count where the domain field is null - q = q.Where("? IS NULL", pg.Ident("domain")) - } else { - q = q.Where("domain = ?", domain) - } - - // don't count the instance account or suspended users - q = q.Where("username != ?", domain).Where("? IS NULL", pg.Ident("suspended_at")) - - return q.Count() -} - -func (i *instanceDB) CountInstanceStatuses(domain string) (int, db.Error) { - q := i.conn.Model(&[]*gtsmodel.Status{}) - - if domain == i.config.Host { - // if the domain is *this* domain, just count where local is true - q = q.Where("local = ?", true) - } else { - // join on the domain of the account - q = q.Join("JOIN accounts AS account ON account.id = status.account_id"). - Where("account.domain = ?", domain) - } - - return q.Count() -} - -func (i *instanceDB) CountInstanceDomains(domain string) (int, db.Error) { - q := i.conn.Model(&[]*gtsmodel.Instance{}) - - if domain == i.config.Host { - // if the domain is *this* domain, just count other instances it knows about - // exclude domains that are blocked - q = q.Where("domain != ?", domain).Where("? IS NULL", pg.Ident("suspended_at")) - } else { - // TODO: implement federated domain counting properly for remote domains - return 0, nil - } - - return q.Count() -} - -func (i *instanceDB) GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) { - i.log.Debug("GetAccountsForInstance") - - accounts := []*gtsmodel.Account{} - - q := i.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC") - - if maxID != "" { - q = q.Where("id < ?", maxID) - } - - if limit > 0 { - q = q.Limit(limit) - } - - err := q.Select() - if err != nil { - if err == pg.ErrNoRows { - return nil, db.ErrNoEntries - } - return nil, err - } - - if len(accounts) == 0 { - return nil, db.ErrNoEntries - } - - return accounts, nil -} diff --git a/internal/db/pg/media.go b/internal/db/pg/media.go deleted file mode 100644 index 618030af3..000000000 --- a/internal/db/pg/media.go +++ /dev/null @@ -1,53 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg - -import ( - "context" - - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" - "github.com/sirupsen/logrus" - "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -type mediaDB struct { - config *config.Config - conn *pg.DB - log *logrus.Logger - cancel context.CancelFunc -} - -func (m *mediaDB) newMediaQ(i interface{}) *orm.Query { - return m.conn.Model(i). - Relation("Account") -} - -func (m *mediaDB) GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, db.Error) { - attachment := >smodel.MediaAttachment{} - - q := m.newMediaQ(attachment). - Where("media_attachment.id = ?", id) - - err := processErrorResponse(q.Select()) - - return attachment, err -} diff --git a/internal/db/pg/mention.go b/internal/db/pg/mention.go deleted file mode 100644 index b31f07b67..000000000 --- a/internal/db/pg/mention.go +++ /dev/null @@ -1,108 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg - -import ( - "context" - - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" - "github.com/sirupsen/logrus" - "github.com/superseriousbusiness/gotosocial/internal/cache" - "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -type mentionDB struct { - config *config.Config - conn *pg.DB - log *logrus.Logger - cancel context.CancelFunc - cache cache.Cache -} - -func (m *mentionDB) cacheMention(id string, mention *gtsmodel.Mention) { - if m.cache == nil { - m.cache = cache.New() - } - - if err := m.cache.Store(id, mention); err != nil { - m.log.Panicf("mentionDB: error storing in cache: %s", err) - } -} - -func (m *mentionDB) mentionCached(id string) (*gtsmodel.Mention, bool) { - if m.cache == nil { - m.cache = cache.New() - return nil, false - } - - mI, err := m.cache.Fetch(id) - if err != nil || mI == nil { - return nil, false - } - - mention, ok := mI.(*gtsmodel.Mention) - if !ok { - m.log.Panicf("mentionDB: cached interface with key %s was not a mention", id) - } - - return mention, true -} - -func (m *mentionDB) newMentionQ(i interface{}) *orm.Query { - return m.conn.Model(i). - Relation("Status"). - Relation("OriginAccount"). - Relation("TargetAccount") -} - -func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) { - if mention, cached := m.mentionCached(id); cached { - return mention, nil - } - - mention := >smodel.Mention{} - - q := m.newMentionQ(mention). - Where("mention.id = ?", id) - - err := processErrorResponse(q.Select()) - - if err == nil && mention != nil { - m.cacheMention(id, mention) - } - - return mention, err -} - -func (m *mentionDB) GetMentions(ids []string) ([]*gtsmodel.Mention, db.Error) { - mentions := []*gtsmodel.Mention{} - - for _, i := range ids { - mention, err := m.GetMention(i) - if err != nil { - return nil, processErrorResponse(err) - } - mentions = append(mentions, mention) - } - - return mentions, nil -} diff --git a/internal/db/pg/notification.go b/internal/db/pg/notification.go deleted file mode 100644 index 281a76d85..000000000 --- a/internal/db/pg/notification.go +++ /dev/null @@ -1,135 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg - -import ( - "context" - - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" - "github.com/sirupsen/logrus" - "github.com/superseriousbusiness/gotosocial/internal/cache" - "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -type notificationDB struct { - config *config.Config - conn *pg.DB - log *logrus.Logger - cancel context.CancelFunc - cache cache.Cache -} - -func (n *notificationDB) cacheNotification(id string, notification *gtsmodel.Notification) { - if n.cache == nil { - n.cache = cache.New() - } - - if err := n.cache.Store(id, notification); err != nil { - n.log.Panicf("notificationDB: error storing in cache: %s", err) - } -} - -func (n *notificationDB) notificationCached(id string) (*gtsmodel.Notification, bool) { - if n.cache == nil { - n.cache = cache.New() - return nil, false - } - - nI, err := n.cache.Fetch(id) - if err != nil || nI == nil { - return nil, false - } - - notification, ok := nI.(*gtsmodel.Notification) - if !ok { - n.log.Panicf("notificationDB: cached interface with key %s was not a notification", id) - } - - return notification, true -} - -func (n *notificationDB) newNotificationQ(i interface{}) *orm.Query { - return n.conn.Model(i). - Relation("OriginAccount"). - Relation("TargetAccount"). - Relation("Status") -} - -func (n *notificationDB) GetNotification(id string) (*gtsmodel.Notification, db.Error) { - if notification, cached := n.notificationCached(id); cached { - return notification, nil - } - - notification := >smodel.Notification{} - - q := n.newNotificationQ(notification). - Where("notification.id = ?", id) - - err := processErrorResponse(q.Select()) - - if err == nil && notification != nil { - n.cacheNotification(id, notification) - } - - return notification, err -} - -func (n *notificationDB) GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) { - // begin by selecting just the IDs - notifIDs := []*gtsmodel.Notification{} - q := n.conn. - Model(¬ifIDs). - Column("id"). - Where("target_account_id = ?", accountID). - Order("id DESC") - - if maxID != "" { - q = q.Where("id < ?", maxID) - } - - if sinceID != "" { - q = q.Where("id > ?", sinceID) - } - - if limit != 0 { - q = q.Limit(limit) - } - - err := processErrorResponse(q.Select()) - if err != nil { - return nil, err - } - - // now we have the IDs, select the notifs one by one - // reason for this is that for each notif, we can instead get it from our cache if it's cached - notifications := []*gtsmodel.Notification{} - for _, notifID := range notifIDs { - notif, err := n.GetNotification(notifID.ID) - errP := processErrorResponse(err) - if errP != nil { - return nil, errP - } - notifications = append(notifications, notif) - } - - return notifications, nil -} diff --git a/internal/db/pg/pg.go b/internal/db/pg/pg.go deleted file mode 100644 index 0437baf02..000000000 --- a/internal/db/pg/pg.go +++ /dev/null @@ -1,420 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg - -import ( - "context" - "crypto/tls" - "crypto/x509" - "encoding/pem" - "errors" - "fmt" - "os" - "strings" - "time" - - "github.com/go-pg/pg/extra/pgdebug" - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" - "github.com/sirupsen/logrus" - "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/id" -) - -var registerTables []interface{} = []interface{}{ - >smodel.StatusToEmoji{}, - >smodel.StatusToTag{}, -} - -// postgresService satisfies the DB interface -type postgresService struct { - db.Account - db.Admin - db.Basic - db.Domain - db.Instance - db.Media - db.Mention - db.Notification - db.Relationship - db.Status - db.Timeline - config *config.Config - conn *pg.DB - log *logrus.Logger - cancel context.CancelFunc -} - -// NewPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface. -// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection. -func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) { - for _, t := range registerTables { - // https://pg.uptrace.dev/orm/many-to-many-relation/ - orm.RegisterTable(t) - } - - opts, err := derivePGOptions(c) - if err != nil { - return nil, fmt.Errorf("could not create postgres service: %s", err) - } - log.Debugf("using pg options: %+v", opts) - - // create a connection - pgCtx, cancel := context.WithCancel(ctx) - conn := pg.Connect(opts).WithContext(pgCtx) - - // this will break the logfmt format we normally log in, - // since we can't choose where pg outputs to and it defaults to - // stdout. So use this option with care! - if log.GetLevel() >= logrus.TraceLevel { - conn.AddQueryHook(pgdebug.DebugHook{ - // Print all queries. - Verbose: true, - }) - } - - // actually *begin* the connection so that we can tell if the db is there and listening - if err := conn.Ping(ctx); err != nil { - cancel() - return nil, fmt.Errorf("db connection error: %s", err) - } - - // print out discovered postgres version - var version string - if _, err = conn.QueryOneContext(ctx, pg.Scan(&version), "SELECT version()"); err != nil { - cancel() - return nil, fmt.Errorf("db connection error: %s", err) - } - log.Infof("connected to postgres version: %s", version) - - ps := &postgresService{ - Account: &accountDB{ - config: c, - conn: conn, - log: log, - cancel: cancel, - }, - Admin: &adminDB{ - config: c, - conn: conn, - log: log, - cancel: cancel, - }, - Basic: &basicDB{ - config: c, - conn: conn, - log: log, - cancel: cancel, - }, - Domain: &domainDB{ - config: c, - conn: conn, - log: log, - cancel: cancel, - }, - Instance: &instanceDB{ - config: c, - conn: conn, - log: log, - cancel: cancel, - }, - Media: &mediaDB{ - config: c, - conn: conn, - log: log, - cancel: cancel, - }, - Mention: &mentionDB{ - config: c, - conn: conn, - log: log, - cancel: cancel, - }, - Notification: ¬ificationDB{ - config: c, - conn: conn, - log: log, - cancel: cancel, - }, - Relationship: &relationshipDB{ - config: c, - conn: conn, - log: log, - cancel: cancel, - }, - Status: &statusDB{ - config: c, - conn: conn, - log: log, - cancel: cancel, - }, - Timeline: &timelineDB{ - config: c, - conn: conn, - log: log, - cancel: cancel, - }, - config: c, - conn: conn, - log: log, - cancel: cancel, - } - - // we can confidently return this useable postgres service now - return ps, nil -} - -/* - HANDY STUFF -*/ - -// derivePGOptions takes an application config and returns either a ready-to-use *pg.Options -// with sensible defaults, or an error if it's not satisfied by the provided config. -func derivePGOptions(c *config.Config) (*pg.Options, error) { - if strings.ToUpper(c.DBConfig.Type) != db.DBTypePostgres { - return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, c.DBConfig.Type) - } - - // validate port - if c.DBConfig.Port == 0 { - return nil, errors.New("no port set") - } - - // validate address - if c.DBConfig.Address == "" { - return nil, errors.New("no address set") - } - - // validate username - if c.DBConfig.User == "" { - return nil, errors.New("no user set") - } - - // validate that there's a password - if c.DBConfig.Password == "" { - return nil, errors.New("no password set") - } - - // validate database - if c.DBConfig.Database == "" { - return nil, errors.New("no database set") - } - - var tlsConfig *tls.Config - switch c.DBConfig.TLSMode { - case config.DBTLSModeDisable, config.DBTLSModeUnset: - break // nothing to do - case config.DBTLSModeEnable: - tlsConfig = &tls.Config{ - InsecureSkipVerify: true, - } - case config.DBTLSModeRequire: - tlsConfig = &tls.Config{ - InsecureSkipVerify: false, - ServerName: c.DBConfig.Address, - } - } - - if tlsConfig != nil && c.DBConfig.TLSCACert != "" { - // load the system cert pool first -- we'll append the given CA cert to this - certPool, err := x509.SystemCertPool() - if err != nil { - return nil, fmt.Errorf("error fetching system CA cert pool: %s", err) - } - - // open the file itself and make sure there's something in it - caCertBytes, err := os.ReadFile(c.DBConfig.TLSCACert) - if err != nil { - return nil, fmt.Errorf("error opening CA certificate at %s: %s", c.DBConfig.TLSCACert, err) - } - if len(caCertBytes) == 0 { - return nil, fmt.Errorf("ca cert at %s was empty", c.DBConfig.TLSCACert) - } - - // make sure we have a PEM block - caPem, _ := pem.Decode(caCertBytes) - if caPem == nil { - return nil, fmt.Errorf("could not parse cert at %s into PEM", c.DBConfig.TLSCACert) - } - - // parse the PEM block into the certificate - caCert, err := x509.ParseCertificate(caPem.Bytes) - if err != nil { - return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", c.DBConfig.TLSCACert, err) - } - - // we're happy, add it to the existing pool and then use this pool in our tls config - certPool.AddCert(caCert) - tlsConfig.RootCAs = certPool - } - - // We can rely on the pg library we're using to set - // sensible defaults for everything we don't set here. - options := &pg.Options{ - Addr: fmt.Sprintf("%s:%d", c.DBConfig.Address, c.DBConfig.Port), - User: c.DBConfig.User, - Password: c.DBConfig.Password, - Database: c.DBConfig.Database, - ApplicationName: c.ApplicationName, - TLSConfig: tlsConfig, - } - - return options, nil -} - -/* - CONVERSION FUNCTIONS -*/ - -// TODO: move these to the type converter, it's bananas that they're here and not there - -func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) { - ogAccount := >smodel.Account{} - if err := ps.conn.Model(ogAccount).Where("id = ?", originAccountID).Select(); err != nil { - return nil, err - } - - menchies := []*gtsmodel.Mention{} - for _, a := range targetAccounts { - // A mentioned account looks like "@test@example.org" or just "@test" for a local account - // -- we can guarantee this from the regex that targetAccounts should have been derived from. - // But we still need to do a bit of fiddling to get what we need here -- the username and domain (if given). - - // 1. trim off the first @ - t := strings.TrimPrefix(a, "@") - - // 2. split the username and domain - s := strings.Split(t, "@") - - // 3. if it's length 1 it's a local account, length 2 means remote, anything else means something is wrong - var local bool - switch len(s) { - case 1: - local = true - case 2: - local = false - default: - return nil, fmt.Errorf("mentioned account format '%s' was not valid", a) - } - - var username, domain string - username = s[0] - if !local { - domain = s[1] - } - - // 4. check we now have a proper username and domain - if username == "" || (!local && domain == "") { - return nil, fmt.Errorf("username or domain for '%s' was nil", a) - } - - // okay we're good now, we can start pulling accounts out of the database - mentionedAccount := >smodel.Account{} - var err error - - // match username + account, case insensitive - if local { - // local user -- should have a null domain - err = ps.conn.Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", pg.Ident("username"), username).Where("? IS NULL", pg.Ident("domain")).Select() - } else { - // remote user -- should have domain defined - err = ps.conn.Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", pg.Ident("username"), username).Where("LOWER(?) = LOWER(?)", pg.Ident("domain"), domain).Select() - } - - if err != nil { - if err == pg.ErrNoRows { - // no result found for this username/domain so just don't include it as a mencho and carry on about our business - ps.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain) - continue - } - // a serious error has happened so bail - return nil, fmt.Errorf("error getting account with username '%s' and domain '%s': %s", username, domain, err) - } - - // id, createdAt and updatedAt will be populated by the db, so we have everything we need! - menchies = append(menchies, >smodel.Mention{ - StatusID: statusID, - OriginAccountID: ogAccount.ID, - OriginAccountURI: ogAccount.URI, - TargetAccountID: mentionedAccount.ID, - NameString: a, - TargetAccountURI: mentionedAccount.URI, - TargetAccountURL: mentionedAccount.URL, - OriginAccount: mentionedAccount, - }) - } - return menchies, nil -} - -func (ps *postgresService) TagStringsToTags(tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) { - newTags := []*gtsmodel.Tag{} - for _, t := range tags { - tag := >smodel.Tag{} - // we can use selectorinsert here to create the new tag if it doesn't exist already - // inserted will be true if this is a new tag we just created - if err := ps.conn.Model(tag).Where("LOWER(?) = LOWER(?)", pg.Ident("name"), t).Select(); err != nil { - if err == pg.ErrNoRows { - // tag doesn't exist yet so populate it - newID, err := id.NewRandomULID() - if err != nil { - return nil, err - } - tag.ID = newID - tag.URL = fmt.Sprintf("%s://%s/tags/%s", ps.config.Protocol, ps.config.Host, t) - tag.Name = t - tag.FirstSeenFromAccountID = originAccountID - tag.CreatedAt = time.Now() - tag.UpdatedAt = time.Now() - tag.Useable = true - tag.Listable = true - } else { - return nil, fmt.Errorf("error getting tag with name %s: %s", t, err) - } - } - - // bail already if the tag isn't useable - if !tag.Useable { - continue - } - tag.LastStatusAt = time.Now() - newTags = append(newTags, tag) - } - return newTags, nil -} - -func (ps *postgresService) EmojiStringsToEmojis(emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) { - newEmojis := []*gtsmodel.Emoji{} - for _, e := range emojis { - emoji := >smodel.Emoji{} - err := ps.conn.Model(emoji).Where("shortcode = ?", e).Where("visible_in_picker = true").Where("disabled = false").Select() - if err != nil { - if err == pg.ErrNoRows { - // no result found for this username/domain so just don't include it as an emoji and carry on about our business - ps.log.Debugf("no emoji found with shortcode %s, skipping it", e) - continue - } - // a serious error has happened so bail - return nil, fmt.Errorf("error getting emoji with shortcode %s: %s", e, err) - } - newEmojis = append(newEmojis, emoji) - } - return newEmojis, nil -} diff --git a/internal/db/pg/pg_test.go b/internal/db/pg/pg_test.go deleted file mode 100644 index c1e10abdf..000000000 --- a/internal/db/pg/pg_test.go +++ /dev/null @@ -1,47 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg_test - -import ( - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/suite" - "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/oauth" -) - -type PGStandardTestSuite struct { - // standard suite interfaces - suite.Suite - config *config.Config - db db.DB - log *logrus.Logger - - // standard suite models - testTokens map[string]*oauth.Token - testClients map[string]*oauth.Client - testApplications map[string]*gtsmodel.Application - testUsers map[string]*gtsmodel.User - testAccounts map[string]*gtsmodel.Account - testAttachments map[string]*gtsmodel.MediaAttachment - testStatuses map[string]*gtsmodel.Status - testTags map[string]*gtsmodel.Tag - testMentions map[string]*gtsmodel.Mention -} diff --git a/internal/db/pg/relationship.go b/internal/db/pg/relationship.go deleted file mode 100644 index 76bd50c76..000000000 --- a/internal/db/pg/relationship.go +++ /dev/null @@ -1,276 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg - -import ( - "context" - "fmt" - - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" - "github.com/sirupsen/logrus" - "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -type relationshipDB struct { - config *config.Config - conn *pg.DB - log *logrus.Logger - cancel context.CancelFunc -} - -func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *orm.Query { - return r.conn.Model(block). - Relation("Account"). - Relation("TargetAccount") -} - -func (r *relationshipDB) newFollowQ(follow interface{}) *orm.Query { - return r.conn.Model(follow). - Relation("Account"). - Relation("TargetAccount") -} - -func (r *relationshipDB) IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, db.Error) { - q := r.conn. - Model(>smodel.Block{}). - Where("account_id = ?", account1). - Where("target_account_id = ?", account2) - - if eitherDirection { - q = q. - WhereOr("target_account_id = ?", account1). - Where("account_id = ?", account2) - } - - return q.Exists() -} - -func (r *relationshipDB) GetBlock(account1 string, account2 string) (*gtsmodel.Block, db.Error) { - block := >smodel.Block{} - - q := r.newBlockQ(block). - Where("block.account_id = ?", account1). - Where("block.target_account_id = ?", account2) - - err := processErrorResponse(q.Select()) - - return block, err -} - -func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) { - rel := >smodel.Relationship{ - ID: targetAccount, - } - - // check if the requesting account follows the target account - follow := >smodel.Follow{} - if err := r.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil { - if err != pg.ErrNoRows { - // a proper error - return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err) - } - // no follow exists so these are all false - rel.Following = false - rel.ShowingReblogs = false - rel.Notifying = false - } else { - // follow exists so we can fill these fields out... - rel.Following = true - rel.ShowingReblogs = follow.ShowReblogs - rel.Notifying = follow.Notify - } - - // check if the target account follows the requesting account - followedBy, err := r.conn.Model(>smodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists() - if err != nil { - return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err) - } - rel.FollowedBy = followedBy - - // check if the requesting account blocks the target account - blocking, err := r.conn.Model(>smodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists() - if err != nil { - return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err) - } - rel.Blocking = blocking - - // check if the target account blocks the requesting account - blockedBy, err := r.conn.Model(>smodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists() - if err != nil { - return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) - } - rel.BlockedBy = blockedBy - - // check if there's a pending following request from requesting account to target account - requested, err := r.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists() - if err != nil { - return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) - } - rel.Requested = requested - - return rel, nil -} - -func (r *relationshipDB) IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) { - if sourceAccount == nil || targetAccount == nil { - return false, nil - } - - q := r.conn. - Model(>smodel.Follow{}). - Where("account_id = ?", sourceAccount.ID). - Where("target_account_id = ?", targetAccount.ID) - - return q.Exists() -} - -func (r *relationshipDB) IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) { - if sourceAccount == nil || targetAccount == nil { - return false, nil - } - - q := r.conn. - Model(>smodel.FollowRequest{}). - Where("account_id = ?", sourceAccount.ID). - Where("target_account_id = ?", targetAccount.ID) - - return q.Exists() -} - -func (r *relationshipDB) IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) { - if account1 == nil || account2 == nil { - return false, nil - } - - // make sure account 1 follows account 2 - f1, err := r.IsFollowing(account1, account2) - if err != nil { - return false, processErrorResponse(err) - } - - // make sure account 2 follows account 1 - f2, err := r.IsFollowing(account2, account1) - if err != nil { - return false, processErrorResponse(err) - } - - return f1 && f2, nil -} - -func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) { - // make sure the original follow request exists - fr := >smodel.FollowRequest{} - if err := r.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil { - if err == pg.ErrMultiRows { - return nil, db.ErrNoEntries - } - return nil, err - } - - // create a new follow to 'replace' the request with - follow := >smodel.Follow{ - ID: fr.ID, - AccountID: originAccountID, - TargetAccountID: targetAccountID, - URI: fr.URI, - } - - // if the follow already exists, just update the URI -- we don't need to do anything else - if _, err := r.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil { - return nil, err - } - - // now remove the follow request - if _, err := r.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil { - return nil, err - } - - return follow, nil -} - -func (r *relationshipDB) GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, db.Error) { - followRequests := []*gtsmodel.FollowRequest{} - - q := r.newFollowQ(&followRequests). - Where("target_account_id = ?", accountID) - - err := processErrorResponse(q.Select()) - - return followRequests, err -} - -func (r *relationshipDB) GetAccountFollows(accountID string) ([]*gtsmodel.Follow, db.Error) { - follows := []*gtsmodel.Follow{} - - q := r.newFollowQ(&follows). - Where("account_id = ?", accountID) - - err := processErrorResponse(q.Select()) - - return follows, err -} - -func (r *relationshipDB) CountAccountFollows(accountID string, localOnly bool) (int, db.Error) { - return r.conn. - Model(&[]*gtsmodel.Follow{}). - Where("account_id = ?", accountID). - Count() -} - -func (r *relationshipDB) GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) { - - follows := []*gtsmodel.Follow{} - - q := r.conn.Model(&follows) - - if localOnly { - // for local accounts let's get where domain is null OR where domain is an empty string, just to be safe - whereGroup := func(q *pg.Query) (*pg.Query, error) { - q = q. - WhereOr("? IS NULL", pg.Ident("a.domain")). - WhereOr("a.domain = ?", "") - return q, nil - } - - q = q.ColumnExpr("follow.*"). - Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)"). - Where("follow.target_account_id = ?", accountID). - WhereGroup(whereGroup) - } else { - q = q.Where("target_account_id = ?", accountID) - } - - if err := q.Select(); err != nil { - if err == pg.ErrNoRows { - return follows, nil - } - return nil, err - } - return follows, nil -} - -func (r *relationshipDB) CountAccountFollowedBy(accountID string, localOnly bool) (int, db.Error) { - return r.conn. - Model(&[]*gtsmodel.Follow{}). - Where("target_account_id = ?", accountID). - Count() -} diff --git a/internal/db/pg/status.go b/internal/db/pg/status.go deleted file mode 100644 index 99790428e..000000000 --- a/internal/db/pg/status.go +++ /dev/null @@ -1,318 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg - -import ( - "container/list" - "context" - "errors" - "time" - - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" - "github.com/sirupsen/logrus" - "github.com/superseriousbusiness/gotosocial/internal/cache" - "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -type statusDB struct { - config *config.Config - conn *pg.DB - log *logrus.Logger - cancel context.CancelFunc - cache cache.Cache -} - -func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) { - if s.cache == nil { - s.cache = cache.New() - } - - if err := s.cache.Store(id, status); err != nil { - s.log.Panicf("statusDB: error storing in cache: %s", err) - } -} - -func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) { - if s.cache == nil { - s.cache = cache.New() - return nil, false - } - - sI, err := s.cache.Fetch(id) - if err != nil || sI == nil { - return nil, false - } - - status, ok := sI.(*gtsmodel.Status) - if !ok { - s.log.Panicf("statusDB: cached interface with key %s was not a status", id) - } - - return status, true -} - -func (s *statusDB) newStatusQ(status interface{}) *orm.Query { - return s.conn.Model(status). - Relation("Attachments"). - Relation("Tags"). - Relation("Mentions"). - Relation("Emojis"). - Relation("Account"). - Relation("InReplyTo"). - Relation("InReplyToAccount"). - Relation("BoostOf"). - Relation("BoostOfAccount"). - Relation("CreatedWithApplication") -} - -func (s *statusDB) newFaveQ(faves interface{}) *orm.Query { - return s.conn.Model(faves). - Relation("Account"). - Relation("TargetAccount"). - Relation("Status") -} - -func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.Error) { - if status, cached := s.statusCached(id); cached { - return status, nil - } - - status := >smodel.Status{} - - q := s.newStatusQ(status). - Where("status.id = ?", id) - - err := processErrorResponse(q.Select()) - - if err == nil && status != nil { - s.cacheStatus(id, status) - } - - return status, err -} - -func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.Error) { - if status, cached := s.statusCached(uri); cached { - return status, nil - } - - status := >smodel.Status{} - - q := s.newStatusQ(status). - Where("LOWER(status.uri) = LOWER(?)", uri) - - err := processErrorResponse(q.Select()) - - if err == nil && status != nil { - s.cacheStatus(uri, status) - } - - return status, err -} - -func (s *statusDB) GetStatusByURL(uri string) (*gtsmodel.Status, db.Error) { - if status, cached := s.statusCached(uri); cached { - return status, nil - } - - status := >smodel.Status{} - - q := s.newStatusQ(status). - Where("LOWER(status.url) = LOWER(?)", uri) - - err := processErrorResponse(q.Select()) - - if err == nil && status != nil { - s.cacheStatus(uri, status) - } - - return status, err -} - -func (s *statusDB) PutStatus(status *gtsmodel.Status) db.Error { - transaction := func(tx *pg.Tx) error { - // create links between this status and any emojis it uses - for _, i := range status.EmojiIDs { - if _, err := tx.Model(>smodel.StatusToEmoji{ - StatusID: status.ID, - EmojiID: i, - }).Insert(); err != nil { - return err - } - } - - // create links between this status and any tags it uses - for _, i := range status.TagIDs { - if _, err := tx.Model(>smodel.StatusToTag{ - StatusID: status.ID, - TagID: i, - }).Insert(); err != nil { - return err - } - } - - // change the status ID of the media attachments to the new status - for _, a := range status.Attachments { - a.StatusID = status.ID - a.UpdatedAt = time.Now() - if _, err := s.conn.Model(a). - Where("id = ?", a.ID). - Update(); err != nil { - return err - } - } - - _, err := tx.Model(status).Insert() - return err - } - - return processErrorResponse(s.conn.RunInTransaction(context.Background(), transaction)) -} - -func (s *statusDB) GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { - parents := []*gtsmodel.Status{} - s.statusParent(status, &parents, onlyDirect) - - return parents, nil -} - -func (s *statusDB) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) { - if status.InReplyToID == "" { - return - } - - parentStatus, err := s.GetStatusByID(status.InReplyToID) - if err == nil { - *foundStatuses = append(*foundStatuses, parentStatus) - } - - if onlyDirect { - return - } - - s.statusParent(parentStatus, foundStatuses, false) -} - -func (s *statusDB) GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) { - foundStatuses := &list.List{} - foundStatuses.PushFront(status) - s.statusChildren(status, foundStatuses, onlyDirect, minID) - - children := []*gtsmodel.Status{} - for e := foundStatuses.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*gtsmodel.Status) - if !ok { - panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) - } - - // only append children, not the overall parent status - if entry.ID != status.ID { - children = append(children, entry) - } - } - - return children, nil -} - -func (s *statusDB) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) { - immediateChildren := []*gtsmodel.Status{} - - q := s.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID) - if minID != "" { - q = q.Where("status.id > ?", minID) - } - - if err := q.Select(); err != nil { - return - } - - for _, child := range immediateChildren { - insertLoop: - for e := foundStatuses.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*gtsmodel.Status) - if !ok { - panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) - } - - if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID { - foundStatuses.InsertAfter(child, e) - break insertLoop - } - } - - // only do one loop if we only want direct children - if onlyDirect { - return - } - s.statusChildren(child, foundStatuses, false, minID) - } -} - -func (s *statusDB) CountStatusReplies(status *gtsmodel.Status) (int, db.Error) { - return s.conn.Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count() -} - -func (s *statusDB) CountStatusReblogs(status *gtsmodel.Status) (int, db.Error) { - return s.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count() -} - -func (s *statusDB) CountStatusFaves(status *gtsmodel.Status) (int, db.Error) { - return s.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count() -} - -func (s *statusDB) IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) { - return s.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists() -} - -func (s *statusDB) IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) { - return s.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists() -} - -func (s *statusDB) IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) { - return s.conn.Model(>smodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists() -} - -func (s *statusDB) IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) { - return s.conn.Model(>smodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists() -} - -func (s *statusDB) GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) { - faves := []*gtsmodel.StatusFave{} - - q := s.newFaveQ(&faves). - Where("status_id = ?", status.ID) - - err := processErrorResponse(q.Select()) - - return faves, err -} - -func (s *statusDB) GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) { - reblogs := []*gtsmodel.Status{} - - q := s.newStatusQ(&reblogs). - Where("boost_of_id = ?", status.ID) - - err := processErrorResponse(q.Select()) - - return reblogs, err -} diff --git a/internal/db/pg/status_test.go b/internal/db/pg/status_test.go deleted file mode 100644 index 8a185757c..000000000 --- a/internal/db/pg/status_test.go +++ /dev/null @@ -1,134 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg_test - -import ( - "fmt" - "testing" - "time" - - "github.com/stretchr/testify/suite" - "github.com/superseriousbusiness/gotosocial/testrig" -) - -type StatusTestSuite struct { - PGStandardTestSuite -} - -func (suite *StatusTestSuite) SetupSuite() { - suite.testTokens = testrig.NewTestTokens() - suite.testClients = testrig.NewTestClients() - suite.testApplications = testrig.NewTestApplications() - suite.testUsers = testrig.NewTestUsers() - suite.testAccounts = testrig.NewTestAccounts() - suite.testAttachments = testrig.NewTestAttachments() - suite.testStatuses = testrig.NewTestStatuses() - suite.testTags = testrig.NewTestTags() - suite.testMentions = testrig.NewTestMentions() -} - -func (suite *StatusTestSuite) SetupTest() { - suite.config = testrig.NewTestConfig() - suite.db = testrig.NewTestDB() - suite.log = testrig.NewTestLog() - - testrig.StandardDBSetup(suite.db, suite.testAccounts) -} - -func (suite *StatusTestSuite) TearDownTest() { - testrig.StandardDBTeardown(suite.db) -} - -func (suite *StatusTestSuite) TestGetStatusByID() { - status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_1_status_1"].ID) - if err != nil { - suite.FailNow(err.Error()) - } - suite.NotNil(status) - suite.NotNil(status.Account) - suite.NotNil(status.CreatedWithApplication) - suite.Nil(status.BoostOf) - suite.Nil(status.BoostOfAccount) - suite.Nil(status.InReplyTo) - suite.Nil(status.InReplyToAccount) -} - -func (suite *StatusTestSuite) TestGetStatusByURI() { - status, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI) - if err != nil { - suite.FailNow(err.Error()) - } - suite.NotNil(status) - suite.NotNil(status.Account) - suite.NotNil(status.CreatedWithApplication) - suite.Nil(status.BoostOf) - suite.Nil(status.BoostOfAccount) - suite.Nil(status.InReplyTo) - suite.Nil(status.InReplyToAccount) -} - -func (suite *StatusTestSuite) TestGetStatusWithExtras() { - status, err := suite.db.GetStatusByID(suite.testStatuses["admin_account_status_1"].ID) - if err != nil { - suite.FailNow(err.Error()) - } - suite.NotNil(status) - suite.NotNil(status.Account) - suite.NotNil(status.CreatedWithApplication) - suite.NotEmpty(status.Tags) - suite.NotEmpty(status.Attachments) - suite.NotEmpty(status.Emojis) -} - -func (suite *StatusTestSuite) TestGetStatusWithMention() { - status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_2_status_5"].ID) - if err != nil { - suite.FailNow(err.Error()) - } - suite.NotNil(status) - suite.NotNil(status.Account) - suite.NotNil(status.CreatedWithApplication) - suite.NotEmpty(status.Mentions) - suite.NotEmpty(status.MentionIDs) - suite.NotNil(status.InReplyTo) - suite.NotNil(status.InReplyToAccount) -} - -func (suite *StatusTestSuite) TestGetStatusTwice() { - before1 := time.Now() - _, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI) - suite.NoError(err) - after1 := time.Now() - duration1 := after1.Sub(before1) - fmt.Println(duration1.Nanoseconds()) - - before2 := time.Now() - _, err = suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI) - suite.NoError(err) - after2 := time.Now() - duration2 := after2.Sub(before2) - fmt.Println(duration2.Nanoseconds()) - - // second retrieval should be several orders faster since it will be cached now - suite.Less(duration2, duration1) -} - -func TestStatusTestSuite(t *testing.T) { - suite.Run(t, new(StatusTestSuite)) -} diff --git a/internal/db/pg/timeline.go b/internal/db/pg/timeline.go deleted file mode 100644 index fa8b07aab..000000000 --- a/internal/db/pg/timeline.go +++ /dev/null @@ -1,210 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see <http://www.gnu.org/licenses/>. -*/ - -package pg - -import ( - "context" - "sort" - - "github.com/go-pg/pg/v10" - "github.com/sirupsen/logrus" - "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -type timelineDB struct { - config *config.Config - conn *pg.DB - log *logrus.Logger - cancel context.CancelFunc -} - -func (t *timelineDB) GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { - statuses := []*gtsmodel.Status{} - q := t.conn.Model(&statuses) - - q = q.ColumnExpr("status.*"). - // Find out who accountID follows. - Join("LEFT JOIN follows AS f ON f.target_account_id = status.account_id"). - // Use a WhereGroup here to specify that we want EITHER statuses posted by accounts that accountID follows, - // OR statuses posted by accountID itself (since a user should be able to see their own statuses). - // - // This is equivalent to something like WHERE ... AND (... OR ...) - // See: https://pg.uptrace.dev/queries/#select - WhereGroup(func(q *pg.Query) (*pg.Query, error) { - q = q.WhereOr("f.account_id = ?", accountID). - WhereOr("status.account_id = ?", accountID) - return q, nil - }). - // Sort by highest ID (newest) to lowest ID (oldest) - Order("status.id DESC") - - if maxID != "" { - // return only statuses LOWER (ie., older) than maxID - q = q.Where("status.id < ?", maxID) - } - - if sinceID != "" { - // return only statuses HIGHER (ie., newer) than sinceID - q = q.Where("status.id > ?", sinceID) - } - - if minID != "" { - // return only statuses HIGHER (ie., newer) than minID - q = q.Where("status.id > ?", minID) - } - - if local { - // return only statuses posted by local account havers - q = q.Where("status.local = ?", local) - } - - if limit > 0 { - // limit amount of statuses returned - q = q.Limit(limit) - } - - err := q.Select() - if err != nil { - if err == pg.ErrNoRows { - return nil, db.ErrNoEntries - } - return nil, err - } - - if len(statuses) == 0 { - return nil, db.ErrNoEntries - } - - return statuses, nil -} - -func (t *timelineDB) GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { - statuses := []*gtsmodel.Status{} - - q := t.conn.Model(&statuses). - Where("visibility = ?", gtsmodel.VisibilityPublic). - Where("? IS NULL", pg.Ident("in_reply_to_id")). - Where("? IS NULL", pg.Ident("in_reply_to_uri")). - Where("? IS NULL", pg.Ident("boost_of_id")). - Order("status.id DESC") - - if maxID != "" { - q = q.Where("status.id < ?", maxID) - } - - if sinceID != "" { - q = q.Where("status.id > ?", sinceID) - } - - if minID != "" { - q = q.Where("status.id > ?", minID) - } - - if local { - q = q.Where("status.local = ?", local) - } - - if limit > 0 { - q = q.Limit(limit) - } - - err := q.Select() - if err != nil { - if err == pg.ErrNoRows { - return nil, db.ErrNoEntries - } - return nil, err - } - - if len(statuses) == 0 { - return nil, db.ErrNoEntries - } - - return statuses, nil -} - -// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20! -// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds. -func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) { - - faves := []*gtsmodel.StatusFave{} - - fq := t.conn.Model(&faves). - Where("account_id = ?", accountID). - Order("id DESC") - - if maxID != "" { - fq = fq.Where("id < ?", maxID) - } - - if minID != "" { - fq = fq.Where("id > ?", minID) - } - - if limit > 0 { - fq = fq.Limit(limit) - } - - err := fq.Select() - if err != nil { - if err == pg.ErrNoRows { - return nil, "", "", db.ErrNoEntries - } - return nil, "", "", err - } - - if len(faves) == 0 { - return nil, "", "", db.ErrNoEntries - } - - // map[statusID]faveID -- we need this to sort statuses by fave ID rather than their own ID - statusesFavesMap := map[string]string{} - - in := []string{} - for _, f := range faves { - statusesFavesMap[f.StatusID] = f.ID - in = append(in, f.StatusID) - } - - statuses := []*gtsmodel.Status{} - err = t.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select() - if err != nil { - if err == pg.ErrNoRows { - return nil, "", "", db.ErrNoEntries - } - return nil, "", "", err - } - - if len(statuses) == 0 { - return nil, "", "", db.ErrNoEntries - } - - // arrange statuses by fave ID - sort.Slice(statuses, func(i int, j int) bool { - statusI := statuses[i] - statusJ := statuses[j] - return statusesFavesMap[statusI.ID] < statusesFavesMap[statusJ.ID] - }) - - nextMaxID := faves[len(faves)-1].ID - prevMinID := faves[0].ID - return statuses, nextMaxID, prevMinID, nil -} diff --git a/internal/db/pg/util.go b/internal/db/pg/util.go deleted file mode 100644 index 17c09b720..000000000 --- a/internal/db/pg/util.go +++ /dev/null @@ -1,25 +0,0 @@ -package pg - -import ( - "strings" - - "github.com/go-pg/pg/v10" - "github.com/superseriousbusiness/gotosocial/internal/db" -) - -// processErrorResponse parses the given error and returns an appropriate DBError. -func processErrorResponse(err error) db.Error { - switch err { - case nil: - return nil - case pg.ErrNoRows: - return db.ErrNoEntries - case pg.ErrMultiRows: - return db.ErrMultipleEntries - default: - if strings.Contains(err.Error(), "duplicate key value violates unique constraint") { - return db.ErrAlreadyExists - } - return err - } -} |
