diff options
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/db.go | 29 | ||||
-rw-r--r-- | internal/db/pg/pg.go | 157 |
2 files changed, 142 insertions, 44 deletions
diff --git a/internal/db/db.go b/internal/db/db.go index cbcd698c9..5609b926f 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -22,7 +22,6 @@ import ( "context" "net" - "github.com/go-fed/activity/pub" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) @@ -33,18 +32,28 @@ const ( // ErrNoEntries is to be returned from the DB interface when no entries are found for a given query. type ErrNoEntries struct{} - func (e ErrNoEntries) Error() string { return "no entries" } +// ErrAlreadyExists is to be returned from the DB interface when an entry already exists for a given query or its constraints. +type ErrAlreadyExists struct{} +func (e ErrAlreadyExists) Error() string { + return "already exists" +} + +type Where struct { + Key string + Value interface{} +} + // DB provides methods for interacting with an underlying database or other storage mechanism (for now, just postgres). // Note that in all of the functions below, the passed interface should be a pointer or a slice, which will then be populated // by whatever is returned from the database. type DB interface { // Federation returns an interface that's compatible with go-fed, for performing federation storage/retrieval functions. // See: https://pkg.go.dev/github.com/go-fed/activity@v1.0.0/pub?utm_source=gopls#Database - Federation() pub.Database + // Federation() federatingdb.FederatingDB /* BASIC DB FUNCTIONALITY @@ -75,7 +84,7 @@ type DB interface { // name of the key to select from. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // In case of no entries, a 'no entries' error will be returned - GetWhere(key string, value interface{}, i interface{}) error + GetWhere(where []Where, i interface{}) error // // GetWhereMany gets one entry where key = value for *ALL* parameters passed as "where". // // That is, if you pass 2 'where' entries, with 1 being Key username and Value test, and the second @@ -109,7 +118,7 @@ type DB interface { // DeleteWhere deletes i where key = value // If i didn't exist anyway, then no error should be returned. - DeleteWhere(key string, value interface{}, i interface{}) error + DeleteWhere(where []Where, i interface{}) error /* HANDY SHORTCUTS @@ -117,7 +126,9 @@ type DB interface { // AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table. // In other words, it should create the follow, and delete the existing follow request. - AcceptFollowRequest(originAccountID string, targetAccountID string) error + // + // It will return the newly created follow for further processing. + AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, error) // CreateInstanceAccount creates an account in the database with the same username as the instance host value. // Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'. @@ -204,6 +215,9 @@ type DB interface { // That is, it returns true if account1 blocks account2, OR if account2 blocks account1. Blocked(account1 string, account2 string) (bool, error) + // GetRelationship retrieves the relationship of the targetAccount to the requestingAccount. + GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error) + // StatusVisible returns true if targetStatus is visible to requestingAccount, based on the // privacy settings of the status, and any blocks/mutes that might exist between the two accounts // or account domains. @@ -222,6 +236,9 @@ type DB interface { // Follows returns true if sourceAccount follows target account, or an error if something goes wrong while finding out. Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) + // FollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out. + FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) + // Mutuals returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out. Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, error) diff --git a/internal/db/pg/pg.go b/internal/db/pg/pg.go index d3590a027..2a4b040d1 100644 --- a/internal/db/pg/pg.go +++ b/internal/db/pg/pg.go @@ -30,7 +30,6 @@ import ( "strings" "time" - "github.com/go-fed/activity/pub" "github.com/go-pg/pg/extra/pgdebug" "github.com/go-pg/pg/v10" "github.com/go-pg/pg/v10/orm" @@ -38,7 +37,6 @@ import ( "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/util" "golang.org/x/crypto/bcrypt" @@ -46,11 +44,11 @@ import ( // postgresService satisfies the DB interface type postgresService struct { - config *config.Config - conn *pg.DB - log *logrus.Logger - cancel context.CancelFunc - federationDB pub.Database + config *config.Config + conn *pg.DB + log *logrus.Logger + cancel context.CancelFunc + // federationDB pub.Database } // NewPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface. @@ -97,9 +95,6 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge cancel: cancel, } - federatingDB := federation.NewFederatingDB(ps, c, log) - ps.federationDB = federatingDB - // we can confidently return this useable postgres service now return ps, nil } @@ -160,14 +155,6 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) { } /* - FEDERATION FUNCTIONALITY -*/ - -func (ps *postgresService) Federation() pub.Database { - return ps.federationDB -} - -/* BASIC DB FUNCTIONALITY */ @@ -229,8 +216,17 @@ func (ps *postgresService) GetByID(id string, i interface{}) error { return nil } -func (ps *postgresService) GetWhere(key string, value interface{}, i interface{}) error { - if err := ps.conn.Model(i).Where("? = ?", pg.Safe(key), value).Select(); err != nil { +func (ps *postgresService) GetWhere(where []db.Where, i interface{}) error { + if len(where) == 0 { + return errors.New("no queries provided") + } + + q := ps.conn.Model(i) + for _, w := range where { + q = q.Where("? = ?", pg.Safe(w.Key), w.Value) + } + + if err := q.Select(); err != nil { if err == pg.ErrNoRows { return db.ErrNoEntries{} } @@ -255,6 +251,9 @@ func (ps *postgresService) GetAll(i interface{}) error { func (ps *postgresService) Put(i interface{}) error { _, err := ps.conn.Model(i).Insert(i) + if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") { + return db.ErrAlreadyExists{} + } return err } @@ -285,20 +284,31 @@ func (ps *postgresService) UpdateOneByID(id string, key string, value interface{ func (ps *postgresService) DeleteByID(id string, i interface{}) error { if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries{} + // if there are no rows *anyway* then that's fine + // just return err if there's an actual error + if err != pg.ErrNoRows { + return err } - return err } return nil } -func (ps *postgresService) DeleteWhere(key string, value interface{}, i interface{}) error { - if _, err := ps.conn.Model(i).Where("? = ?", pg.Safe(key), value).Delete(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries{} +func (ps *postgresService) DeleteWhere(where []db.Where, i interface{}) error { + if len(where) == 0 { + return errors.New("no queries provided") + } + + q := ps.conn.Model(i) + for _, w := range where { + q = q.Where("? = ?", pg.Safe(w.Key), w.Value) + } + + if _, err := q.Delete(); err != nil { + // if there are no rows *anyway* then that's fine + // just return err if there's an actual error + if err != pg.ErrNoRows { + return err } - return err } return nil } @@ -307,30 +317,34 @@ func (ps *postgresService) DeleteWhere(key string, value interface{}, i interfac HANDY SHORTCUTS */ -func (ps *postgresService) AcceptFollowRequest(originAccountID string, targetAccountID string) error { +func (ps *postgresService) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, error) { + // make sure the original follow request exists fr := >smodel.FollowRequest{} if err := ps.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil { if err == pg.ErrMultiRows { - return db.ErrNoEntries{} + return nil, db.ErrNoEntries{} } - return err + return nil, err } + // create a new follow to 'replace' the request with follow := >smodel.Follow{ AccountID: originAccountID, TargetAccountID: targetAccountID, URI: fr.URI, } - if _, err := ps.conn.Model(follow).Insert(); err != nil { - return err + // if the follow already exists, just update the URI -- we don't need to do anything else + if _, err := ps.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil { + return nil, err } + // now remove the follow request if _, err := ps.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil { - return err + return nil, err } - return nil + return follow, nil } func (ps *postgresService) CreateInstanceAccount() error { @@ -681,6 +695,60 @@ func (ps *postgresService) Blocked(account1 string, account2 string) (bool, erro return blocked, nil } +func (ps *postgresService) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error) { + r := >smodel.Relationship{ + ID: targetAccount, + } + + // check if the requesting account follows the target account + follow := >smodel.Follow{} + if err := ps.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil { + if err != pg.ErrNoRows { + // a proper error + return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err) + } + // no follow exists so these are all false + r.Following = false + r.ShowingReblogs = false + r.Notifying = false + } else { + // follow exists so we can fill these fields out... + r.Following = true + r.ShowingReblogs = follow.ShowReblogs + r.Notifying = follow.Notify + } + + // check if the target account follows the requesting account + followedBy, err := ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists() + if err != nil { + return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err) + } + r.FollowedBy = followedBy + + // check if the requesting account blocks the target account + blocking, err := ps.conn.Model(>smodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists() + if err != nil { + return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err) + } + r.Blocking = blocking + + // check if the target account blocks the requesting account + blockedBy, err := ps.conn.Model(>smodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists() + if err != nil { + return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) + } + r.BlockedBy = blockedBy + + // check if there's a pending following request from requesting account to target account + requested, err := ps.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists() + if err != nil { + return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) + } + r.Requested = requested + + return r, nil +} + func (ps *postgresService) StatusVisible(targetStatus *gtsmodel.Status, targetAccount *gtsmodel.Account, requestingAccount *gtsmodel.Account, relevantAccounts *gtsmodel.RelevantAccounts) (bool, error) { l := ps.log.WithField("func", "StatusVisible") @@ -853,6 +921,10 @@ func (ps *postgresService) Follows(sourceAccount *gtsmodel.Account, targetAccoun return ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists() } +func (ps *postgresService) FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) { + return ps.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists() +} + func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, error) { // make sure account 1 follows account 2 f1, err := ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", account1.ID).Where("target_account_id = ?", account2.ID).Exists() @@ -1036,6 +1108,11 @@ func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel. */ func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) { + ogAccount := >smodel.Account{} + if err := ps.conn.Model(ogAccount).Where("id = ?", originAccountID).Select(); err != nil { + return nil, err + } + menchies := []*gtsmodel.Mention{} for _, a := range targetAccounts { // A mentioned account looks like "@test@example.org" or just "@test" for a local account @@ -1093,9 +1170,13 @@ func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, ori // id, createdAt and updatedAt will be populated by the db, so we have everything we need! menchies = append(menchies, >smodel.Mention{ - StatusID: statusID, - OriginAccountID: originAccountID, - TargetAccountID: mentionedAccount.ID, + StatusID: statusID, + OriginAccountID: ogAccount.ID, + OriginAccountURI: ogAccount.URI, + TargetAccountID: mentionedAccount.ID, + NameString: a, + MentionedAccountURI: mentionedAccount.URI, + GTSAccount: mentionedAccount, }) } return menchies, nil |