summaryrefslogtreecommitdiff
path: root/internal/db/bundb/status.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/status.go')
-rw-r--r--internal/db/bundb/status.go145
1 files changed, 66 insertions, 79 deletions
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go
index 1d5acf0fc..9464cfadf 100644
--- a/internal/db/bundb/status.go
+++ b/internal/db/bundb/status.go
@@ -21,7 +21,6 @@ package bundb
import (
"container/list"
"context"
- "errors"
"time"
"github.com/superseriousbusiness/gotosocial/internal/cache"
@@ -35,6 +34,11 @@ type statusDB struct {
config *config.Config
conn *DBConn
cache *cache.StatusCache
+
+ // TODO: keep method definitions in same place but instead have receiver
+ // all point to one single "db" type, so they can all share methods
+ // and caches where necessary
+ accounts *accountDB
}
func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
@@ -51,30 +55,6 @@ func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
Relation("CreatedWithApplication")
}
-func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status {
- if status.InReplyToID != "" && status.InReplyTo == nil {
- // TODO: do we want to keep this possibly recursive strategy?
-
- if inReplyTo, cached := s.cache.GetByID(status.InReplyToID); cached {
- status.InReplyTo = inReplyTo
- } else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil {
- status.InReplyTo = inReplyTo
- }
- }
-
- if status.BoostOfID != "" && status.BoostOf == nil {
- // TODO: do we want to keep this possibly recursive strategy?
-
- if boostOf, cached := s.cache.GetByID(status.BoostOfID); cached {
- status.BoostOf = boostOf
- } else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil {
- status.BoostOf = boostOf
- }
- }
-
- return status
-}
-
func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
return s.conn.
NewSelect().
@@ -85,64 +65,79 @@ func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
}
func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) {
- if status, cached := s.cache.GetByID(id); cached {
- return status, nil
- }
-
- status := &gtsmodel.Status{}
-
- q := s.newStatusQ(status).
- Where("status.id = ?", id)
-
- err := q.Scan(ctx)
- if err != nil {
- return nil, s.conn.ProcessError(err)
- }
-
- s.cache.Put(status)
- return s.getAttachedStatuses(ctx, status), nil
+ return s.getStatus(
+ ctx,
+ func() (*gtsmodel.Status, bool) {
+ return s.cache.GetByID(id)
+ },
+ func(status *gtsmodel.Status) error {
+ return s.newStatusQ(status).Where("status.id = ?", id).Scan(ctx)
+ },
+ )
}
func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
- if status, cached := s.cache.GetByURI(uri); cached {
- return status, nil
- }
-
- status := &gtsmodel.Status{}
+ return s.getStatus(
+ ctx,
+ func() (*gtsmodel.Status, bool) {
+ return s.cache.GetByURI(uri)
+ },
+ func(status *gtsmodel.Status) error {
+ return s.newStatusQ(status).Where("LOWER(status.uri) = LOWER(?)", uri).Scan(ctx)
+ },
+ )
+}
- q := s.newStatusQ(status).
- Where("LOWER(status.uri) = LOWER(?)", uri)
+func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) {
+ return s.getStatus(
+ ctx,
+ func() (*gtsmodel.Status, bool) {
+ return s.cache.GetByURL(url)
+ },
+ func(status *gtsmodel.Status) error {
+ return s.newStatusQ(status).Where("LOWER(status.url) = LOWER(?)", url).Scan(ctx)
+ },
+ )
+}
- err := q.Scan(ctx)
- if err != nil {
- return nil, s.conn.ProcessError(err)
- }
+func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Status, bool), dbQuery func(*gtsmodel.Status) error) (*gtsmodel.Status, db.Error) {
+ // Attempt to fetch cached status
+ status, cached := cacheGet()
- s.cache.Put(status)
- return s.getAttachedStatuses(ctx, status), nil
-}
+ if !cached {
+ status = &gtsmodel.Status{}
-func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) {
- if status, cached := s.cache.GetByURL(url); cached {
- return status, nil
- }
+ // Not cached! Perform database query
+ err := dbQuery(status)
+ if err != nil {
+ return nil, s.conn.ProcessError(err)
+ }
- status := &gtsmodel.Status{}
+ // If there is boosted, fetch from DB also
+ if status.BoostOfID != "" {
+ boostOf, err := s.GetStatusByID(ctx, status.BoostOfID)
+ if err == nil {
+ status.BoostOf = boostOf
+ }
+ }
- q := s.newStatusQ(status).
- Where("LOWER(status.url) = LOWER(?)", url)
+ // Place in the cache
+ s.cache.Put(status)
+ }
- err := q.Scan(ctx)
+ // Set the status author account
+ author, err := s.accounts.GetAccountByID(ctx, status.AccountID)
if err != nil {
- return nil, s.conn.ProcessError(err)
+ return nil, err
}
- s.cache.Put(status)
- return s.getAttachedStatuses(ctx, status), nil
+ // Return the prepared status
+ status.Account = author
+ return status, nil
}
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
- transaction := func(ctx context.Context, tx bun.Tx) error {
+ return s.conn.RunInTx(ctx, func(tx bun.Tx) error {
// create links between this status and any emojis it uses
for _, i := range status.EmojiIDs {
if _, err := tx.NewInsert().Model(&gtsmodel.StatusToEmoji{
@@ -174,10 +169,10 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
}
}
+ // Finally, insert the status
_, err := tx.NewInsert().Model(status).Exec(ctx)
return err
- }
- return s.conn.ProcessError(s.conn.RunInTx(ctx, nil, transaction))
+ })
}
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
@@ -210,12 +205,8 @@ func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Statu
children := []*gtsmodel.Status{}
for e := foundStatuses.Front(); e != nil; e = e.Next() {
- entry, ok := e.Value.(*gtsmodel.Status)
- if !ok {
- panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
- }
-
// only append children, not the overall parent status
+ entry := e.Value.(*gtsmodel.Status)
if entry.ID != status.ID {
children = append(children, entry)
}
@@ -242,11 +233,7 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status,
for _, child := range immediateChildren {
insertLoop:
for e := foundStatuses.Front(); e != nil; e = e.Next() {
- entry, ok := e.Value.(*gtsmodel.Status)
- if !ok {
- panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
- }
-
+ entry := e.Value.(*gtsmodel.Status)
if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {
foundStatuses.InsertAfter(child, e)
break insertLoop