diff options
author | 2021-08-29 15:41:41 +0100 | |
---|---|---|
committer | 2021-08-29 16:41:41 +0200 | |
commit | ed462245730bd7832019bd43e0bc1c9d1c055e8e (patch) | |
tree | 1caad78ea6aabf5ea93c93a8ade97176b4889500 /internal/db/bundb/relationship.go | |
parent | Mention fixup (#167) (diff) | |
download | gotosocial-ed462245730bd7832019bd43e0bc1c9d1c055e8e.tar.xz |
Add SQLite support, fix un-thread-safe DB caches, small performance f… (#172)
* Add SQLite support, fix un-thread-safe DB caches, small performance fixes
Signed-off-by: kim (grufwub) <grufwub@gmail.com>
* add SQLite licenses to README
Signed-off-by: kim (grufwub) <grufwub@gmail.com>
* appease the linter, and fix my dumbass-ery
Signed-off-by: kim (grufwub) <grufwub@gmail.com>
* make requested changes
Signed-off-by: kim (grufwub) <grufwub@gmail.com>
* add back comment
Signed-off-by: kim (grufwub) <grufwub@gmail.com>
Diffstat (limited to 'internal/db/bundb/relationship.go')
-rw-r--r-- | internal/db/bundb/relationship.go | 53 |
1 files changed, 27 insertions, 26 deletions
diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index 95426f122..56b752593 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -23,7 +23,6 @@ import ( "database/sql" "fmt" - "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -32,8 +31,7 @@ import ( type relationshipDB struct { config *config.Config - conn *bun.DB - log *logrus.Logger + conn *DBConn } func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *bun.SelectQuery { @@ -66,7 +64,7 @@ func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account Where("account_id = ?", account2) } - return exists(ctx, q) + return r.conn.Exists(ctx, q) } func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) { @@ -76,9 +74,11 @@ func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 Where("block.account_id = ?", account1). Where("block.target_account_id = ?", account2) - err := processErrorResponse(q.Scan(ctx)) - - return block, err + err := q.Scan(ctx) + if err != nil { + return nil, r.conn.ProcessError(err) + } + return block, nil } func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) { @@ -176,7 +176,7 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode Where("target_account_id = ?", targetAccount.ID). Limit(1) - return exists(ctx, q) + return r.conn.Exists(ctx, q) } func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) { @@ -190,7 +190,7 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g Where("account_id = ?", sourceAccount.ID). Where("target_account_id = ?", targetAccount.ID) - return exists(ctx, q) + return r.conn.Exists(ctx, q) } func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) { @@ -201,13 +201,13 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmod // make sure account 1 follows account 2 f1, err := r.IsFollowing(ctx, account1, account2) if err != nil { - return false, processErrorResponse(err) + return false, err } // make sure account 2 follows account 1 f2, err := r.IsFollowing(ctx, account2, account1) if err != nil { - return false, processErrorResponse(err) + return false, err } return f1 && f2, nil @@ -222,7 +222,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI Where("account_id = ?", originAccountID). Where("target_account_id = ?", targetAccountID). Scan(ctx); err != nil { - return nil, processErrorResponse(err) + return nil, r.conn.ProcessError(err) } // create a new follow to 'replace' the request with @@ -239,7 +239,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI Model(follow). On("CONFLICT ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI). Exec(ctx); err != nil { - return nil, processErrorResponse(err) + return nil, r.conn.ProcessError(err) } // now remove the follow request @@ -249,7 +249,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI Where("account_id = ?", originAccountID). Where("target_account_id = ?", targetAccountID). Exec(ctx); err != nil { - return nil, processErrorResponse(err) + return nil, r.conn.ProcessError(err) } return follow, nil @@ -261,9 +261,11 @@ func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID q := r.newFollowQ(&followRequests). Where("target_account_id = ?", accountID) - err := processErrorResponse(q.Scan(ctx)) - - return followRequests, err + err := q.Scan(ctx) + if err != nil { + return nil, r.conn.ProcessError(err) + } + return followRequests, nil } func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, db.Error) { @@ -272,9 +274,11 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string q := r.newFollowQ(&follows). Where("account_id = ?", accountID) - err := processErrorResponse(q.Scan(ctx)) - - return follows, err + err := q.Scan(ctx) + if err != nil { + return nil, r.conn.ProcessError(err) + } + return follows, nil } func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { @@ -286,7 +290,6 @@ func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID stri } func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) { - follows := []*gtsmodel.Follow{} q := r.conn. @@ -302,11 +305,9 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str q = q.Where("target_account_id = ?", accountID) } - if err := q.Scan(ctx); err != nil { - if err == sql.ErrNoRows { - return follows, nil - } - return nil, processErrorResponse(err) + err := q.Scan(ctx) + if err != nil && err != sql.ErrNoRows { + return nil, r.conn.ProcessError(err) } return follows, nil } |