diff options
Diffstat (limited to 'internal/db/bundb/status.go')
-rw-r--r-- | internal/db/bundb/status.go | 251 |
1 files changed, 132 insertions, 119 deletions
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index bc72c2849..b4ae40607 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -25,7 +25,7 @@ import ( "errors" "time" - "github.com/superseriousbusiness/gotosocial/internal/cache" + "codeberg.org/gruf/go-cache/v3/result" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" @@ -33,15 +33,28 @@ import ( ) type statusDB struct { - conn *DBConn - cache *cache.StatusCache - - // TODO: keep method definitions in same place but instead have receiver - // all point to one single "db" type, so they can all share methods - // and caches where necessary + conn *DBConn + cache *result.Cache[*gtsmodel.Status] accounts *accountDB } +func (s *statusDB) init() { + // Initialize status result cache + s.cache = result.NewSized([]result.Lookup{ + {Name: "ID"}, + {Name: "URI"}, + {Name: "URL"}, + }, func(s1 *gtsmodel.Status) *gtsmodel.Status { + s2 := new(gtsmodel.Status) + *s2 = *s1 + return s2 + }, 1000) + + // Set cache TTL and start sweep routine + s.cache.SetTTL(time.Minute*5, false) + s.cache.Start(time.Second * 10) +} + func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { return s.conn. NewSelect(). @@ -68,61 +81,62 @@ func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery { func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) { return s.getStatus( ctx, - func() (*gtsmodel.Status, bool) { - return s.cache.GetByID(id) - }, + "ID", func(status *gtsmodel.Status) error { return s.newStatusQ(status).Where("? = ?", bun.Ident("status.id"), id).Scan(ctx) }, + id, ) } func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) { return s.getStatus( ctx, - func() (*gtsmodel.Status, bool) { - return s.cache.GetByURI(uri) - }, + "URI", func(status *gtsmodel.Status) error { return s.newStatusQ(status).Where("? = ?", bun.Ident("status.uri"), uri).Scan(ctx) }, + uri, ) } func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) { return s.getStatus( ctx, - func() (*gtsmodel.Status, bool) { - return s.cache.GetByURL(url) - }, + "URL", func(status *gtsmodel.Status) error { return s.newStatusQ(status).Where("? = ?", bun.Ident("status.url"), url).Scan(ctx) }, + url, ) } -func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Status, bool), dbQuery func(*gtsmodel.Status) error) (*gtsmodel.Status, db.Error) { - // Attempt to fetch cached status - status, cached := cacheGet() - - if !cached { - status = >smodel.Status{} +func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Status) error, keyParts ...any) (*gtsmodel.Status, db.Error) { + // Fetch status from database cache with loader callback + status, err := s.cache.Load(lookup, func() (*gtsmodel.Status, error) { + var status gtsmodel.Status // Not cached! Perform database query - if err := dbQuery(status); err != nil { + if err := dbQuery(&status); err != nil { return nil, s.conn.ProcessError(err) } // If there is boosted, fetch from DB also if status.BoostOfID != "" { - boostOf, err := s.GetStatusByID(ctx, status.BoostOfID) - if err == nil { - status.BoostOf = boostOf + status.BoostOf = >smodel.Status{} + err := s.newStatusQ(status.BoostOf). + Where("? = ?", bun.Ident("status.id"), status.BoostOfID). + Scan(ctx) + if err != nil { + return nil, s.conn.ProcessError(err) } } - // Place in the cache - s.cache.Put(status) + return &status, nil + }, keyParts...) + if err != nil { + // error already processed + return nil, err } // Set the status author account @@ -137,73 +151,66 @@ func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Sta } func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error { - err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { - // create links between this status and any emojis it uses - for _, i := range status.EmojiIDs { - if _, err := tx. - NewInsert(). - Model(>smodel.StatusToEmoji{ - StatusID: status.ID, - EmojiID: i, - }).Exec(ctx); err != nil { - err = s.conn.errProc(err) - if !errors.Is(err, db.ErrAlreadyExists) { - return err + return s.cache.Store(status, func() error { + // It is safe to run this database transaction within cache.Store + // as the cache does not attempt a mutex lock until AFTER hook. + // + return s.conn.RunInTx(ctx, func(tx bun.Tx) error { + // create links between this status and any emojis it uses + for _, i := range status.EmojiIDs { + if _, err := tx. + NewInsert(). + Model(>smodel.StatusToEmoji{ + StatusID: status.ID, + EmojiID: i, + }).Exec(ctx); err != nil { + err = s.conn.ProcessError(err) + if !errors.Is(err, db.ErrAlreadyExists) { + return err + } } } - } - // create links between this status and any tags it uses - for _, i := range status.TagIDs { - if _, err := tx. - NewInsert(). - Model(>smodel.StatusToTag{ - StatusID: status.ID, - TagID: i, - }).Exec(ctx); err != nil { - err = s.conn.errProc(err) - if !errors.Is(err, db.ErrAlreadyExists) { - return err + // create links between this status and any tags it uses + for _, i := range status.TagIDs { + if _, err := tx. + NewInsert(). + Model(>smodel.StatusToTag{ + StatusID: status.ID, + TagID: i, + }).Exec(ctx); err != nil { + err = s.conn.ProcessError(err) + if !errors.Is(err, db.ErrAlreadyExists) { + return err + } } } - } - // change the status ID of the media attachments to the new status - for _, a := range status.Attachments { - a.StatusID = status.ID - a.UpdatedAt = time.Now() - if _, err := tx. - NewUpdate(). - Model(a). - Where("? = ?", bun.Ident("media_attachment.id"), a.ID). - Exec(ctx); err != nil { - err = s.conn.errProc(err) - if !errors.Is(err, db.ErrAlreadyExists) { - return err + // change the status ID of the media attachments to the new status + for _, a := range status.Attachments { + a.StatusID = status.ID + a.UpdatedAt = time.Now() + if _, err := tx. + NewUpdate(). + Model(a). + Where("? = ?", bun.Ident("media_attachment.id"), a.ID). + Exec(ctx); err != nil { + err = s.conn.ProcessError(err) + if !errors.Is(err, db.ErrAlreadyExists) { + return err + } } } - } - // Finally, insert the status - if _, err := tx. - NewInsert(). - Model(status). - Exec(ctx); err != nil { + // Finally, insert the status + _, err := tx.NewInsert().Model(status).Exec(ctx) return err - } - - return nil + }) }) - if err != nil { - return s.conn.ProcessError(err) - } - - s.cache.Put(status) - return nil } -func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*gtsmodel.Status, db.Error) { - err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { +func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) db.Error { + if err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { // create links between this status and any emojis it uses for _, i := range status.EmojiIDs { if _, err := tx. @@ -212,7 +219,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (* StatusID: status.ID, EmojiID: i, }).Exec(ctx); err != nil { - err = s.conn.errProc(err) + err = s.conn.ProcessError(err) if !errors.Is(err, db.ErrAlreadyExists) { return err } @@ -227,14 +234,14 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (* StatusID: status.ID, TagID: i, }).Exec(ctx); err != nil { - err = s.conn.errProc(err) + err = s.conn.ProcessError(err) if !errors.Is(err, db.ErrAlreadyExists) { return err } } } - // change the status ID of the media attachments to this status + // change the status ID of the media attachments to the new status for _, a := range status.Attachments { a.StatusID = status.ID a.UpdatedAt = time.Now() @@ -243,31 +250,31 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (* Model(a). Where("? = ?", bun.Ident("media_attachment.id"), a.ID). Exec(ctx); err != nil { - return err + err = s.conn.ProcessError(err) + if !errors.Is(err, db.ErrAlreadyExists) { + return err + } } } - // Finally, update the status itself - if _, err := tx. + // Finally, insert the status + _, err := tx. NewUpdate(). Model(status). Where("? = ?", bun.Ident("status.id"), status.ID). - Exec(ctx); err != nil { - return err - } - - return nil - }) - if err != nil { - return nil, s.conn.ProcessError(err) + Exec(ctx) + return err + }); err != nil { + return err } - s.cache.Put(status) - return status, nil + // Drop any old value from cache by this ID + s.cache.Invalidate("ID", status.ID) + return nil } func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { - err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { + if err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { // delete links between this status and any emojis it uses if _, err := tx. NewDelete(). @@ -296,36 +303,41 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { } return nil - }) - if err != nil { - return s.conn.ProcessError(err) + }); err != nil { + return err } - s.cache.Invalidate(id) + // Drop any old value from cache by this ID + s.cache.Invalidate("ID", id) return nil } func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { - parents := []*gtsmodel.Status{} - s.statusParent(ctx, status, &parents, onlyDirect) - return parents, nil -} - -func (s *statusDB) statusParent(ctx context.Context, status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) { - if status.InReplyToID == "" { - return + if onlyDirect { + // Only want the direct parent, no further than first level + parent, err := s.GetStatusByID(ctx, status.InReplyToID) + if err != nil { + return nil, err + } + return []*gtsmodel.Status{parent}, nil } - parentStatus, err := s.GetStatusByID(ctx, status.InReplyToID) - if err == nil { - *foundStatuses = append(*foundStatuses, parentStatus) - } + var parents []*gtsmodel.Status - if onlyDirect { - return + for id := status.InReplyToID; id != ""; { + parent, err := s.GetStatusByID(ctx, id) + if err != nil { + return nil, err + } + + // Append parent to slice + parents = append(parents, parent) + + // Set the next parent ID + id = parent.InReplyToID } - s.statusParent(ctx, parentStatus, foundStatuses, false) + return parents, nil } func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) { @@ -350,7 +362,7 @@ func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Statu } func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) { - childIDs := []string{} + var childIDs []string q := s.conn. NewSelect(). @@ -471,6 +483,7 @@ func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) if err := q.Scan(ctx); err != nil { return nil, s.conn.ProcessError(err) } + return faves, nil } |