diff options
Diffstat (limited to 'internal/db/bundb/status.go')
-rw-r--r-- | internal/db/bundb/status.go | 114 |
1 files changed, 41 insertions, 73 deletions
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index 2019322ac..1d5acf0fc 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -24,7 +24,6 @@ import ( "errors" "time" - "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" @@ -34,38 +33,8 @@ import ( type statusDB struct { config *config.Config - conn *bun.DB - log *logrus.Logger - cache cache.Cache -} - -func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) { - if s.cache == nil { - s.cache = cache.New() - } - - if err := s.cache.Store(id, status); err != nil { - s.log.Panicf("statusDB: error storing in cache: %s", err) - } -} - -func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) { - if s.cache == nil { - s.cache = cache.New() - return nil, false - } - - sI, err := s.cache.Fetch(id) - if err != nil || sI == nil { - return nil, false - } - - status, ok := sI.(*gtsmodel.Status) - if !ok { - s.log.Panicf("statusDB: cached interface with key %s was not a status", id) - } - - return status, true + conn *DBConn + cache *cache.StatusCache } func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { @@ -84,7 +53,9 @@ func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status { if status.InReplyToID != "" && status.InReplyTo == nil { - if inReplyTo, cached := s.statusCached(status.InReplyToID); cached { + // TODO: do we want to keep this possibly recursive strategy? + + if inReplyTo, cached := s.cache.GetByID(status.InReplyToID); cached { status.InReplyTo = inReplyTo } else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil { status.InReplyTo = inReplyTo @@ -92,7 +63,9 @@ func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Sta } if status.BoostOfID != "" && status.BoostOf == nil { - if boostOf, cached := s.statusCached(status.BoostOfID); cached { + // TODO: do we want to keep this possibly recursive strategy? + + if boostOf, cached := s.cache.GetByID(status.BoostOfID); cached { status.BoostOf = boostOf } else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil { status.BoostOf = boostOf @@ -112,29 +85,26 @@ func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery { } func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) { - if status, cached := s.statusCached(id); cached { + if status, cached := s.cache.GetByID(id); cached { return status, nil } - status := new(gtsmodel.Status) + status := >smodel.Status{} q := s.newStatusQ(status). Where("status.id = ?", id) - err := processErrorResponse(q.Scan(ctx)) + err := q.Scan(ctx) if err != nil { - return nil, err - } - - if status != nil { - s.cacheStatus(id, status) + return nil, s.conn.ProcessError(err) } - return s.getAttachedStatuses(ctx, status), err + s.cache.Put(status) + return s.getAttachedStatuses(ctx, status), nil } func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) { - if status, cached := s.statusCached(uri); cached { + if status, cached := s.cache.GetByURI(uri); cached { return status, nil } @@ -143,38 +113,32 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St q := s.newStatusQ(status). Where("LOWER(status.uri) = LOWER(?)", uri) - err := processErrorResponse(q.Scan(ctx)) + err := q.Scan(ctx) if err != nil { - return nil, err + return nil, s.conn.ProcessError(err) } - if status != nil { - s.cacheStatus(uri, status) - } - - return s.getAttachedStatuses(ctx, status), err + s.cache.Put(status) + return s.getAttachedStatuses(ctx, status), nil } -func (s *statusDB) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) { - if status, cached := s.statusCached(uri); cached { +func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) { + if status, cached := s.cache.GetByURL(url); cached { return status, nil } status := >smodel.Status{} q := s.newStatusQ(status). - Where("LOWER(status.url) = LOWER(?)", uri) + Where("LOWER(status.url) = LOWER(?)", url) - err := processErrorResponse(q.Scan(ctx)) + err := q.Scan(ctx) if err != nil { - return nil, err - } - - if status != nil { - s.cacheStatus(uri, status) + return nil, s.conn.ProcessError(err) } - return s.getAttachedStatuses(ctx, status), err + s.cache.Put(status) + return s.getAttachedStatuses(ctx, status), nil } func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error { @@ -213,14 +177,12 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er _, err := tx.NewInsert().Model(status).Exec(ctx) return err } - - return processErrorResponse(s.conn.RunInTx(ctx, nil, transaction)) + return s.conn.ProcessError(s.conn.RunInTx(ctx, nil, transaction)) } func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { parents := []*gtsmodel.Status{} s.statusParent(ctx, status, &parents, onlyDirect) - return parents, nil } @@ -318,7 +280,7 @@ func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, Where("status_id = ?", status.ID). Where("account_id = ?", accountID) - return exists(ctx, q) + return s.conn.Exists(ctx, q) } func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { @@ -328,7 +290,7 @@ func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Sta Where("boost_of_id = ?", status.ID). Where("account_id = ?", accountID) - return exists(ctx, q) + return s.conn.Exists(ctx, q) } func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { @@ -338,7 +300,7 @@ func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, Where("status_id = ?", status.ID). Where("account_id = ?", accountID) - return exists(ctx, q) + return s.conn.Exists(ctx, q) } func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { @@ -348,7 +310,7 @@ func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.St Where("status_id = ?", status.ID). Where("account_id = ?", accountID) - return exists(ctx, q) + return s.conn.Exists(ctx, q) } func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) { @@ -357,8 +319,11 @@ func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) q := s.newFaveQ(&faves). Where("status_id = ?", status.ID) - err := processErrorResponse(q.Scan(ctx)) - return faves, err + err := q.Scan(ctx) + if err != nil { + return nil, s.conn.ProcessError(err) + } + return faves, nil } func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) { @@ -367,6 +332,9 @@ func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status q := s.newStatusQ(&reblogs). Where("boost_of_id = ?", status.ID) - err := processErrorResponse(q.Scan(ctx)) - return reblogs, err + err := q.Scan(ctx) + if err != nil { + return nil, s.conn.ProcessError(err) + } + return reblogs, nil } |