summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/db.go29
-rw-r--r--internal/db/pg/pg.go157
2 files changed, 142 insertions, 44 deletions
diff --git a/internal/db/db.go b/internal/db/db.go
index cbcd698c9..5609b926f 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -22,7 +22,6 @@ import (
"context"
"net"
- "github.com/go-fed/activity/pub"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
@@ -33,18 +32,28 @@ const (
// ErrNoEntries is to be returned from the DB interface when no entries are found for a given query.
type ErrNoEntries struct{}
-
func (e ErrNoEntries) Error() string {
return "no entries"
}
+// ErrAlreadyExists is to be returned from the DB interface when an entry already exists for a given query or its constraints.
+type ErrAlreadyExists struct{}
+func (e ErrAlreadyExists) Error() string {
+ return "already exists"
+}
+
+type Where struct {
+ Key string
+ Value interface{}
+}
+
// DB provides methods for interacting with an underlying database or other storage mechanism (for now, just postgres).
// Note that in all of the functions below, the passed interface should be a pointer or a slice, which will then be populated
// by whatever is returned from the database.
type DB interface {
// Federation returns an interface that's compatible with go-fed, for performing federation storage/retrieval functions.
// See: https://pkg.go.dev/github.com/go-fed/activity@v1.0.0/pub?utm_source=gopls#Database
- Federation() pub.Database
+ // Federation() federatingdb.FederatingDB
/*
BASIC DB FUNCTIONALITY
@@ -75,7 +84,7 @@ type DB interface {
// name of the key to select from.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
- GetWhere(key string, value interface{}, i interface{}) error
+ GetWhere(where []Where, i interface{}) error
// // GetWhereMany gets one entry where key = value for *ALL* parameters passed as "where".
// // That is, if you pass 2 'where' entries, with 1 being Key username and Value test, and the second
@@ -109,7 +118,7 @@ type DB interface {
// DeleteWhere deletes i where key = value
// If i didn't exist anyway, then no error should be returned.
- DeleteWhere(key string, value interface{}, i interface{}) error
+ DeleteWhere(where []Where, i interface{}) error
/*
HANDY SHORTCUTS
@@ -117,7 +126,9 @@ type DB interface {
// AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table.
// In other words, it should create the follow, and delete the existing follow request.
- AcceptFollowRequest(originAccountID string, targetAccountID string) error
+ //
+ // It will return the newly created follow for further processing.
+ AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, error)
// CreateInstanceAccount creates an account in the database with the same username as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'.
@@ -204,6 +215,9 @@ type DB interface {
// That is, it returns true if account1 blocks account2, OR if account2 blocks account1.
Blocked(account1 string, account2 string) (bool, error)
+ // GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
+ GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error)
+
// StatusVisible returns true if targetStatus is visible to requestingAccount, based on the
// privacy settings of the status, and any blocks/mutes that might exist between the two accounts
// or account domains.
@@ -222,6 +236,9 @@ type DB interface {
// Follows returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error)
+ // FollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
+ FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error)
+
// Mutuals returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, error)
diff --git a/internal/db/pg/pg.go b/internal/db/pg/pg.go
index d3590a027..2a4b040d1 100644
--- a/internal/db/pg/pg.go
+++ b/internal/db/pg/pg.go
@@ -30,7 +30,6 @@ import (
"strings"
"time"
- "github.com/go-fed/activity/pub"
"github.com/go-pg/pg/extra/pgdebug"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
@@ -38,7 +37,6 @@ import (
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
"golang.org/x/crypto/bcrypt"
@@ -46,11 +44,11 @@ import (
// postgresService satisfies the DB interface
type postgresService struct {
- config *config.Config
- conn *pg.DB
- log *logrus.Logger
- cancel context.CancelFunc
- federationDB pub.Database
+ config *config.Config
+ conn *pg.DB
+ log *logrus.Logger
+ cancel context.CancelFunc
+ // federationDB pub.Database
}
// NewPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
@@ -97,9 +95,6 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
cancel: cancel,
}
- federatingDB := federation.NewFederatingDB(ps, c, log)
- ps.federationDB = federatingDB
-
// we can confidently return this useable postgres service now
return ps, nil
}
@@ -160,14 +155,6 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
}
/*
- FEDERATION FUNCTIONALITY
-*/
-
-func (ps *postgresService) Federation() pub.Database {
- return ps.federationDB
-}
-
-/*
BASIC DB FUNCTIONALITY
*/
@@ -229,8 +216,17 @@ func (ps *postgresService) GetByID(id string, i interface{}) error {
return nil
}
-func (ps *postgresService) GetWhere(key string, value interface{}, i interface{}) error {
- if err := ps.conn.Model(i).Where("? = ?", pg.Safe(key), value).Select(); err != nil {
+func (ps *postgresService) GetWhere(where []db.Where, i interface{}) error {
+ if len(where) == 0 {
+ return errors.New("no queries provided")
+ }
+
+ q := ps.conn.Model(i)
+ for _, w := range where {
+ q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
+ }
+
+ if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
@@ -255,6 +251,9 @@ func (ps *postgresService) GetAll(i interface{}) error {
func (ps *postgresService) Put(i interface{}) error {
_, err := ps.conn.Model(i).Insert(i)
+ if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
+ return db.ErrAlreadyExists{}
+ }
return err
}
@@ -285,20 +284,31 @@ func (ps *postgresService) UpdateOneByID(id string, key string, value interface{
func (ps *postgresService) DeleteByID(id string, i interface{}) error {
if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
+ // if there are no rows *anyway* then that's fine
+ // just return err if there's an actual error
+ if err != pg.ErrNoRows {
+ return err
}
- return err
}
return nil
}
-func (ps *postgresService) DeleteWhere(key string, value interface{}, i interface{}) error {
- if _, err := ps.conn.Model(i).Where("? = ?", pg.Safe(key), value).Delete(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
+func (ps *postgresService) DeleteWhere(where []db.Where, i interface{}) error {
+ if len(where) == 0 {
+ return errors.New("no queries provided")
+ }
+
+ q := ps.conn.Model(i)
+ for _, w := range where {
+ q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
+ }
+
+ if _, err := q.Delete(); err != nil {
+ // if there are no rows *anyway* then that's fine
+ // just return err if there's an actual error
+ if err != pg.ErrNoRows {
+ return err
}
- return err
}
return nil
}
@@ -307,30 +317,34 @@ func (ps *postgresService) DeleteWhere(key string, value interface{}, i interfac
HANDY SHORTCUTS
*/
-func (ps *postgresService) AcceptFollowRequest(originAccountID string, targetAccountID string) error {
+func (ps *postgresService) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {
+ // make sure the original follow request exists
fr := &gtsmodel.FollowRequest{}
if err := ps.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil {
if err == pg.ErrMultiRows {
- return db.ErrNoEntries{}
+ return nil, db.ErrNoEntries{}
}
- return err
+ return nil, err
}
+ // create a new follow to 'replace' the request with
follow := &gtsmodel.Follow{
AccountID: originAccountID,
TargetAccountID: targetAccountID,
URI: fr.URI,
}
- if _, err := ps.conn.Model(follow).Insert(); err != nil {
- return err
+ // if the follow already exists, just update the URI -- we don't need to do anything else
+ if _, err := ps.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil {
+ return nil, err
}
+ // now remove the follow request
if _, err := ps.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil {
- return err
+ return nil, err
}
- return nil
+ return follow, nil
}
func (ps *postgresService) CreateInstanceAccount() error {
@@ -681,6 +695,60 @@ func (ps *postgresService) Blocked(account1 string, account2 string) (bool, erro
return blocked, nil
}
+func (ps *postgresService) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error) {
+ r := &gtsmodel.Relationship{
+ ID: targetAccount,
+ }
+
+ // check if the requesting account follows the target account
+ follow := &gtsmodel.Follow{}
+ if err := ps.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil {
+ if err != pg.ErrNoRows {
+ // a proper error
+ return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
+ }
+ // no follow exists so these are all false
+ r.Following = false
+ r.ShowingReblogs = false
+ r.Notifying = false
+ } else {
+ // follow exists so we can fill these fields out...
+ r.Following = true
+ r.ShowingReblogs = follow.ShowReblogs
+ r.Notifying = follow.Notify
+ }
+
+ // check if the target account follows the requesting account
+ followedBy, err := ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
+ if err != nil {
+ return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
+ }
+ r.FollowedBy = followedBy
+
+ // check if the requesting account blocks the target account
+ blocking, err := ps.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
+ if err != nil {
+ return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
+ }
+ r.Blocking = blocking
+
+ // check if the target account blocks the requesting account
+ blockedBy, err := ps.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
+ if err != nil {
+ return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
+ }
+ r.BlockedBy = blockedBy
+
+ // check if there's a pending following request from requesting account to target account
+ requested, err := ps.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
+ if err != nil {
+ return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
+ }
+ r.Requested = requested
+
+ return r, nil
+}
+
func (ps *postgresService) StatusVisible(targetStatus *gtsmodel.Status, targetAccount *gtsmodel.Account, requestingAccount *gtsmodel.Account, relevantAccounts *gtsmodel.RelevantAccounts) (bool, error) {
l := ps.log.WithField("func", "StatusVisible")
@@ -853,6 +921,10 @@ func (ps *postgresService) Follows(sourceAccount *gtsmodel.Account, targetAccoun
return ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
}
+func (ps *postgresService) FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) {
+ return ps.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
+}
+
func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, error) {
// make sure account 1 follows account 2
f1, err := ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", account1.ID).Where("target_account_id = ?", account2.ID).Exists()
@@ -1036,6 +1108,11 @@ func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.
*/
func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) {
+ ogAccount := &gtsmodel.Account{}
+ if err := ps.conn.Model(ogAccount).Where("id = ?", originAccountID).Select(); err != nil {
+ return nil, err
+ }
+
menchies := []*gtsmodel.Mention{}
for _, a := range targetAccounts {
// A mentioned account looks like "@test@example.org" or just "@test" for a local account
@@ -1093,9 +1170,13 @@ func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, ori
// id, createdAt and updatedAt will be populated by the db, so we have everything we need!
menchies = append(menchies, &gtsmodel.Mention{
- StatusID: statusID,
- OriginAccountID: originAccountID,
- TargetAccountID: mentionedAccount.ID,
+ StatusID: statusID,
+ OriginAccountID: ogAccount.ID,
+ OriginAccountURI: ogAccount.URI,
+ TargetAccountID: mentionedAccount.ID,
+ NameString: a,
+ MentionedAccountURI: mentionedAccount.URI,
+ GTSAccount: mentionedAccount,
})
}
return menchies, nil