summaryrefslogtreecommitdiff
path: root/internal/db/bundb/status.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/status.go')
-rw-r--r--internal/db/bundb/status.go114
1 files changed, 41 insertions, 73 deletions
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go
index 2019322ac..1d5acf0fc 100644
--- a/internal/db/bundb/status.go
+++ b/internal/db/bundb/status.go
@@ -24,7 +24,6 @@ import (
"errors"
"time"
- "github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
@@ -34,38 +33,8 @@ import (
type statusDB struct {
config *config.Config
- conn *bun.DB
- log *logrus.Logger
- cache cache.Cache
-}
-
-func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) {
- if s.cache == nil {
- s.cache = cache.New()
- }
-
- if err := s.cache.Store(id, status); err != nil {
- s.log.Panicf("statusDB: error storing in cache: %s", err)
- }
-}
-
-func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) {
- if s.cache == nil {
- s.cache = cache.New()
- return nil, false
- }
-
- sI, err := s.cache.Fetch(id)
- if err != nil || sI == nil {
- return nil, false
- }
-
- status, ok := sI.(*gtsmodel.Status)
- if !ok {
- s.log.Panicf("statusDB: cached interface with key %s was not a status", id)
- }
-
- return status, true
+ conn *DBConn
+ cache *cache.StatusCache
}
func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
@@ -84,7 +53,9 @@ func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status {
if status.InReplyToID != "" && status.InReplyTo == nil {
- if inReplyTo, cached := s.statusCached(status.InReplyToID); cached {
+ // TODO: do we want to keep this possibly recursive strategy?
+
+ if inReplyTo, cached := s.cache.GetByID(status.InReplyToID); cached {
status.InReplyTo = inReplyTo
} else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil {
status.InReplyTo = inReplyTo
@@ -92,7 +63,9 @@ func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Sta
}
if status.BoostOfID != "" && status.BoostOf == nil {
- if boostOf, cached := s.statusCached(status.BoostOfID); cached {
+ // TODO: do we want to keep this possibly recursive strategy?
+
+ if boostOf, cached := s.cache.GetByID(status.BoostOfID); cached {
status.BoostOf = boostOf
} else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil {
status.BoostOf = boostOf
@@ -112,29 +85,26 @@ func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
}
func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) {
- if status, cached := s.statusCached(id); cached {
+ if status, cached := s.cache.GetByID(id); cached {
return status, nil
}
- status := new(gtsmodel.Status)
+ status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("status.id = ?", id)
- err := processErrorResponse(q.Scan(ctx))
+ err := q.Scan(ctx)
if err != nil {
- return nil, err
- }
-
- if status != nil {
- s.cacheStatus(id, status)
+ return nil, s.conn.ProcessError(err)
}
- return s.getAttachedStatuses(ctx, status), err
+ s.cache.Put(status)
+ return s.getAttachedStatuses(ctx, status), nil
}
func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
- if status, cached := s.statusCached(uri); cached {
+ if status, cached := s.cache.GetByURI(uri); cached {
return status, nil
}
@@ -143,38 +113,32 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St
q := s.newStatusQ(status).
Where("LOWER(status.uri) = LOWER(?)", uri)
- err := processErrorResponse(q.Scan(ctx))
+ err := q.Scan(ctx)
if err != nil {
- return nil, err
+ return nil, s.conn.ProcessError(err)
}
- if status != nil {
- s.cacheStatus(uri, status)
- }
-
- return s.getAttachedStatuses(ctx, status), err
+ s.cache.Put(status)
+ return s.getAttachedStatuses(ctx, status), nil
}
-func (s *statusDB) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
- if status, cached := s.statusCached(uri); cached {
+func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) {
+ if status, cached := s.cache.GetByURL(url); cached {
return status, nil
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
- Where("LOWER(status.url) = LOWER(?)", uri)
+ Where("LOWER(status.url) = LOWER(?)", url)
- err := processErrorResponse(q.Scan(ctx))
+ err := q.Scan(ctx)
if err != nil {
- return nil, err
- }
-
- if status != nil {
- s.cacheStatus(uri, status)
+ return nil, s.conn.ProcessError(err)
}
- return s.getAttachedStatuses(ctx, status), err
+ s.cache.Put(status)
+ return s.getAttachedStatuses(ctx, status), nil
}
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
@@ -213,14 +177,12 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
_, err := tx.NewInsert().Model(status).Exec(ctx)
return err
}
-
- return processErrorResponse(s.conn.RunInTx(ctx, nil, transaction))
+ return s.conn.ProcessError(s.conn.RunInTx(ctx, nil, transaction))
}
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
parents := []*gtsmodel.Status{}
s.statusParent(ctx, status, &parents, onlyDirect)
-
return parents, nil
}
@@ -318,7 +280,7 @@ func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status,
Where("status_id = ?", status.ID).
Where("account_id = ?", accountID)
- return exists(ctx, q)
+ return s.conn.Exists(ctx, q)
}
func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
@@ -328,7 +290,7 @@ func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Sta
Where("boost_of_id = ?", status.ID).
Where("account_id = ?", accountID)
- return exists(ctx, q)
+ return s.conn.Exists(ctx, q)
}
func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
@@ -338,7 +300,7 @@ func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status,
Where("status_id = ?", status.ID).
Where("account_id = ?", accountID)
- return exists(ctx, q)
+ return s.conn.Exists(ctx, q)
}
func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
@@ -348,7 +310,7 @@ func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.St
Where("status_id = ?", status.ID).
Where("account_id = ?", accountID)
- return exists(ctx, q)
+ return s.conn.Exists(ctx, q)
}
func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
@@ -357,8 +319,11 @@ func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status)
q := s.newFaveQ(&faves).
Where("status_id = ?", status.ID)
- err := processErrorResponse(q.Scan(ctx))
- return faves, err
+ err := q.Scan(ctx)
+ if err != nil {
+ return nil, s.conn.ProcessError(err)
+ }
+ return faves, nil
}
func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
@@ -367,6 +332,9 @@ func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status
q := s.newStatusQ(&reblogs).
Where("boost_of_id = ?", status.ID)
- err := processErrorResponse(q.Scan(ctx))
- return reblogs, err
+ err := q.Scan(ctx)
+ if err != nil {
+ return nil, s.conn.ProcessError(err)
+ }
+ return reblogs, nil
}