summaryrefslogtreecommitdiff
path: root/internal/db/bundb/relationship.go
diff options
context:
space:
mode:
authorLibravatar kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com>2021-08-29 15:41:41 +0100
committerLibravatar GitHub <noreply@github.com>2021-08-29 16:41:41 +0200
commited462245730bd7832019bd43e0bc1c9d1c055e8e (patch)
tree1caad78ea6aabf5ea93c93a8ade97176b4889500 /internal/db/bundb/relationship.go
parentMention fixup (#167) (diff)
downloadgotosocial-ed462245730bd7832019bd43e0bc1c9d1c055e8e.tar.xz
Add SQLite support, fix un-thread-safe DB caches, small performance f… (#172)
* Add SQLite support, fix un-thread-safe DB caches, small performance fixes Signed-off-by: kim (grufwub) <grufwub@gmail.com> * add SQLite licenses to README Signed-off-by: kim (grufwub) <grufwub@gmail.com> * appease the linter, and fix my dumbass-ery Signed-off-by: kim (grufwub) <grufwub@gmail.com> * make requested changes Signed-off-by: kim (grufwub) <grufwub@gmail.com> * add back comment Signed-off-by: kim (grufwub) <grufwub@gmail.com>
Diffstat (limited to 'internal/db/bundb/relationship.go')
-rw-r--r--internal/db/bundb/relationship.go53
1 files changed, 27 insertions, 26 deletions
diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go
index 95426f122..56b752593 100644
--- a/internal/db/bundb/relationship.go
+++ b/internal/db/bundb/relationship.go
@@ -23,7 +23,6 @@ import (
"database/sql"
"fmt"
- "github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -32,8 +31,7 @@ import (
type relationshipDB struct {
config *config.Config
- conn *bun.DB
- log *logrus.Logger
+ conn *DBConn
}
func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *bun.SelectQuery {
@@ -66,7 +64,7 @@ func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account
Where("account_id = ?", account2)
}
- return exists(ctx, q)
+ return r.conn.Exists(ctx, q)
}
func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
@@ -76,9 +74,11 @@ func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2
Where("block.account_id = ?", account1).
Where("block.target_account_id = ?", account2)
- err := processErrorResponse(q.Scan(ctx))
-
- return block, err
+ err := q.Scan(ctx)
+ if err != nil {
+ return nil, r.conn.ProcessError(err)
+ }
+ return block, nil
}
func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
@@ -176,7 +176,7 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode
Where("target_account_id = ?", targetAccount.ID).
Limit(1)
- return exists(ctx, q)
+ return r.conn.Exists(ctx, q)
}
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
@@ -190,7 +190,7 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g
Where("account_id = ?", sourceAccount.ID).
Where("target_account_id = ?", targetAccount.ID)
- return exists(ctx, q)
+ return r.conn.Exists(ctx, q)
}
func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
@@ -201,13 +201,13 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmod
// make sure account 1 follows account 2
f1, err := r.IsFollowing(ctx, account1, account2)
if err != nil {
- return false, processErrorResponse(err)
+ return false, err
}
// make sure account 2 follows account 1
f2, err := r.IsFollowing(ctx, account2, account1)
if err != nil {
- return false, processErrorResponse(err)
+ return false, err
}
return f1 && f2, nil
@@ -222,7 +222,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
Where("account_id = ?", originAccountID).
Where("target_account_id = ?", targetAccountID).
Scan(ctx); err != nil {
- return nil, processErrorResponse(err)
+ return nil, r.conn.ProcessError(err)
}
// create a new follow to 'replace' the request with
@@ -239,7 +239,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
Model(follow).
On("CONFLICT ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).
Exec(ctx); err != nil {
- return nil, processErrorResponse(err)
+ return nil, r.conn.ProcessError(err)
}
// now remove the follow request
@@ -249,7 +249,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
Where("account_id = ?", originAccountID).
Where("target_account_id = ?", targetAccountID).
Exec(ctx); err != nil {
- return nil, processErrorResponse(err)
+ return nil, r.conn.ProcessError(err)
}
return follow, nil
@@ -261,9 +261,11 @@ func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID
q := r.newFollowQ(&followRequests).
Where("target_account_id = ?", accountID)
- err := processErrorResponse(q.Scan(ctx))
-
- return followRequests, err
+ err := q.Scan(ctx)
+ if err != nil {
+ return nil, r.conn.ProcessError(err)
+ }
+ return followRequests, nil
}
func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, db.Error) {
@@ -272,9 +274,11 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string
q := r.newFollowQ(&follows).
Where("account_id = ?", accountID)
- err := processErrorResponse(q.Scan(ctx))
-
- return follows, err
+ err := q.Scan(ctx)
+ if err != nil {
+ return nil, r.conn.ProcessError(err)
+ }
+ return follows, nil
}
func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
@@ -286,7 +290,6 @@ func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID stri
}
func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
-
follows := []*gtsmodel.Follow{}
q := r.conn.
@@ -302,11 +305,9 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
q = q.Where("target_account_id = ?", accountID)
}
- if err := q.Scan(ctx); err != nil {
- if err == sql.ErrNoRows {
- return follows, nil
- }
- return nil, processErrorResponse(err)
+ err := q.Scan(ctx)
+ if err != nil && err != sql.ErrNoRows {
+ return nil, r.conn.ProcessError(err)
}
return follows, nil
}