diff options
Diffstat (limited to 'internal/db/bundb/status.go')
-rw-r--r-- | internal/db/bundb/status.go | 145 |
1 files changed, 66 insertions, 79 deletions
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index 1d5acf0fc..9464cfadf 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -21,7 +21,6 @@ package bundb import ( "container/list" "context" - "errors" "time" "github.com/superseriousbusiness/gotosocial/internal/cache" @@ -35,6 +34,11 @@ type statusDB struct { config *config.Config conn *DBConn cache *cache.StatusCache + + // TODO: keep method definitions in same place but instead have receiver + // all point to one single "db" type, so they can all share methods + // and caches where necessary + accounts *accountDB } func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { @@ -51,30 +55,6 @@ func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { Relation("CreatedWithApplication") } -func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status { - if status.InReplyToID != "" && status.InReplyTo == nil { - // TODO: do we want to keep this possibly recursive strategy? - - if inReplyTo, cached := s.cache.GetByID(status.InReplyToID); cached { - status.InReplyTo = inReplyTo - } else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil { - status.InReplyTo = inReplyTo - } - } - - if status.BoostOfID != "" && status.BoostOf == nil { - // TODO: do we want to keep this possibly recursive strategy? - - if boostOf, cached := s.cache.GetByID(status.BoostOfID); cached { - status.BoostOf = boostOf - } else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil { - status.BoostOf = boostOf - } - } - - return status -} - func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery { return s.conn. NewSelect(). @@ -85,64 +65,79 @@ func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery { } func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) { - if status, cached := s.cache.GetByID(id); cached { - return status, nil - } - - status := >smodel.Status{} - - q := s.newStatusQ(status). - Where("status.id = ?", id) - - err := q.Scan(ctx) - if err != nil { - return nil, s.conn.ProcessError(err) - } - - s.cache.Put(status) - return s.getAttachedStatuses(ctx, status), nil + return s.getStatus( + ctx, + func() (*gtsmodel.Status, bool) { + return s.cache.GetByID(id) + }, + func(status *gtsmodel.Status) error { + return s.newStatusQ(status).Where("status.id = ?", id).Scan(ctx) + }, + ) } func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) { - if status, cached := s.cache.GetByURI(uri); cached { - return status, nil - } - - status := >smodel.Status{} + return s.getStatus( + ctx, + func() (*gtsmodel.Status, bool) { + return s.cache.GetByURI(uri) + }, + func(status *gtsmodel.Status) error { + return s.newStatusQ(status).Where("LOWER(status.uri) = LOWER(?)", uri).Scan(ctx) + }, + ) +} - q := s.newStatusQ(status). - Where("LOWER(status.uri) = LOWER(?)", uri) +func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) { + return s.getStatus( + ctx, + func() (*gtsmodel.Status, bool) { + return s.cache.GetByURL(url) + }, + func(status *gtsmodel.Status) error { + return s.newStatusQ(status).Where("LOWER(status.url) = LOWER(?)", url).Scan(ctx) + }, + ) +} - err := q.Scan(ctx) - if err != nil { - return nil, s.conn.ProcessError(err) - } +func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Status, bool), dbQuery func(*gtsmodel.Status) error) (*gtsmodel.Status, db.Error) { + // Attempt to fetch cached status + status, cached := cacheGet() - s.cache.Put(status) - return s.getAttachedStatuses(ctx, status), nil -} + if !cached { + status = >smodel.Status{} -func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) { - if status, cached := s.cache.GetByURL(url); cached { - return status, nil - } + // Not cached! Perform database query + err := dbQuery(status) + if err != nil { + return nil, s.conn.ProcessError(err) + } - status := >smodel.Status{} + // If there is boosted, fetch from DB also + if status.BoostOfID != "" { + boostOf, err := s.GetStatusByID(ctx, status.BoostOfID) + if err == nil { + status.BoostOf = boostOf + } + } - q := s.newStatusQ(status). - Where("LOWER(status.url) = LOWER(?)", url) + // Place in the cache + s.cache.Put(status) + } - err := q.Scan(ctx) + // Set the status author account + author, err := s.accounts.GetAccountByID(ctx, status.AccountID) if err != nil { - return nil, s.conn.ProcessError(err) + return nil, err } - s.cache.Put(status) - return s.getAttachedStatuses(ctx, status), nil + // Return the prepared status + status.Account = author + return status, nil } func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error { - transaction := func(ctx context.Context, tx bun.Tx) error { + return s.conn.RunInTx(ctx, func(tx bun.Tx) error { // create links between this status and any emojis it uses for _, i := range status.EmojiIDs { if _, err := tx.NewInsert().Model(>smodel.StatusToEmoji{ @@ -174,10 +169,10 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er } } + // Finally, insert the status _, err := tx.NewInsert().Model(status).Exec(ctx) return err - } - return s.conn.ProcessError(s.conn.RunInTx(ctx, nil, transaction)) + }) } func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { @@ -210,12 +205,8 @@ func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Statu children := []*gtsmodel.Status{} for e := foundStatuses.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*gtsmodel.Status) - if !ok { - panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) - } - // only append children, not the overall parent status + entry := e.Value.(*gtsmodel.Status) if entry.ID != status.ID { children = append(children, entry) } @@ -242,11 +233,7 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, for _, child := range immediateChildren { insertLoop: for e := foundStatuses.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*gtsmodel.Status) - if !ok { - panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status")) - } - + entry := e.Value.(*gtsmodel.Status) if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID { foundStatuses.InsertAfter(child, e) break insertLoop |