diff options
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/bundb/status.go | 82 | ||||
-rw-r--r-- | internal/db/bundb/status_test.go | 14 | ||||
-rw-r--r-- | internal/db/bundb/util.go | 3 | ||||
-rw-r--r-- | internal/db/status.go | 10 |
4 files changed, 37 insertions, 72 deletions
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index 80346412c..dd161e1ec 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -18,7 +18,6 @@ package bundb import ( - "container/list" "context" "errors" "time" @@ -515,16 +514,7 @@ func (s *statusDB) GetStatusesUsingEmoji(ctx context.Context, emojiID string) ([ return s.GetStatusesByIDs(ctx, statusIDs) } -func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error) { - if onlyDirect { - // Only want the direct parent, no further than first level - parent, err := s.GetStatusByID(ctx, status.InReplyToID) - if err != nil { - return nil, err - } - return []*gtsmodel.Status{parent}, nil - } - +func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, error) { var parents []*gtsmodel.Status for id := status.InReplyToID; id != ""; { @@ -533,7 +523,7 @@ func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status return nil, err } - // Append parent to slice + // Append parent status to slice parents = append(parents, parent) // Set the next parent ID @@ -543,65 +533,31 @@ func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status return parents, nil } -func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error) { - foundStatuses := &list.List{} - foundStatuses.PushFront(status) - s.statusChildren(ctx, status, foundStatuses, onlyDirect, minID) - - children := []*gtsmodel.Status{} - for e := foundStatuses.Front(); e != nil; e = e.Next() { - // only append children, not the overall parent status - entry, ok := e.Value.(*gtsmodel.Status) - if !ok { - log.Panic(ctx, "found status could not be asserted to *gtsmodel.Status") - } - - if entry.ID != status.ID { - children = append(children, entry) - } +func (s *statusDB) GetStatusChildren(ctx context.Context, statusID string) ([]*gtsmodel.Status, error) { + // Get all replies for the currently set status. + replies, err := s.GetStatusReplies(ctx, statusID) + if err != nil { + return nil, err } - return children, nil -} + // Make estimated preallocation based on direct replies. + children := make([]*gtsmodel.Status, 0, len(replies)*2) -func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) { - childIDs, err := s.getStatusReplyIDs(ctx, status.ID) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - log.Errorf(ctx, "error getting status %s children: %v", status.ID, err) - return - } + for _, status := range replies { + // Append status to children. + children = append(children, status) - for _, id := range childIDs { - if id <= minID { - continue - } - - // Fetch child with ID from database - child, err := s.GetStatusByID(ctx, id) + // Further, recursively get all children for this reply. + grandChildren, err := s.GetStatusChildren(ctx, status.ID) if err != nil { - log.Errorf(ctx, "error getting child status %q: %v", id, err) - continue - } - - insertLoop: - for e := foundStatuses.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*gtsmodel.Status) - if !ok { - log.Panic(ctx, "found status could not be asserted to *gtsmodel.Status") - } - - if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID { - foundStatuses.InsertAfter(child, e) - break insertLoop - } + return nil, err } - // if we're not only looking for direct children of status, then do the same children-finding - // operation for the found child status too. - if !onlyDirect { - s.statusChildren(ctx, child, foundStatuses, false, minID) - } + // Append all sub children after status. + children = append(children, grandChildren...) } + + return children, nil } func (s *statusDB) GetStatusReplies(ctx context.Context, statusID string) ([]*gtsmodel.Status, error) { diff --git a/internal/db/bundb/status_test.go b/internal/db/bundb/status_test.go index a69608796..c0ff6c0da 100644 --- a/internal/db/bundb/status_test.go +++ b/internal/db/bundb/status_test.go @@ -163,9 +163,21 @@ func (suite *StatusTestSuite) TestGetStatusTwice() { suite.Less(duration2, duration1) } +func (suite *StatusTestSuite) TestGetStatusReplies() { + targetStatus := suite.testStatuses["local_account_1_status_1"] + children, err := suite.db.GetStatusReplies(context.Background(), targetStatus.ID) + suite.NoError(err) + suite.Len(children, 2) + for _, c := range children { + suite.Equal(targetStatus.URI, c.InReplyToURI) + suite.Equal(targetStatus.AccountID, c.InReplyToAccountID) + suite.Equal(targetStatus.ID, c.InReplyToID) + } +} + func (suite *StatusTestSuite) TestGetStatusChildren() { targetStatus := suite.testStatuses["local_account_1_status_1"] - children, err := suite.db.GetStatusChildren(context.Background(), targetStatus, true, "") + children, err := suite.db.GetStatusChildren(context.Background(), targetStatus.ID) suite.NoError(err) suite.Len(children, 2) for _, c := range children { diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go index 1d820d081..a2bc87b88 100644 --- a/internal/db/bundb/util.go +++ b/internal/db/bundb/util.go @@ -18,6 +18,7 @@ package bundb import ( + "slices" "strings" "github.com/superseriousbusiness/gotosocial/internal/cache" @@ -99,7 +100,7 @@ func loadPagedIDs(cache *cache.SliceCache[string], key string, page *paging.Page // order. Depending on the paging requested // this may be an unexpected order. if page.GetOrder().Ascending() { - ids = paging.Reverse(ids) + slices.Reverse(ids) } // Page the resulting IDs. diff --git a/internal/db/status.go b/internal/db/status.go index 0be37421a..1ebf503a8 100644 --- a/internal/db/status.go +++ b/internal/db/status.go @@ -55,7 +55,7 @@ type Status interface { // GetStatusesUsingEmoji fetches all status models using emoji with given ID stored in their 'emojis' column. GetStatusesUsingEmoji(ctx context.Context, emojiID string) ([]*gtsmodel.Status, error) - // GetStatusReplies returns the *direct* (i.e. in_reply_to_id column) replies to this status ID. + // GetStatusReplies returns the *direct* (i.e. in_reply_to_id column) replies to this status ID, ordered DESC by ID. GetStatusReplies(ctx context.Context, statusID string) ([]*gtsmodel.Status, error) // CountStatusReplies returns the number of stored *direct* (i.e. in_reply_to_id column) replies to this status ID. @@ -71,14 +71,10 @@ type Status interface { IsStatusBoostedBy(ctx context.Context, statusID string, accountID string) (bool, error) // GetStatusParents gets the parent statuses of a given status. - // - // If onlyDirect is true, only the immediate parent will be returned. - GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error) + GetStatusParents(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, error) // GetStatusChildren gets the child statuses of a given status. - // - // If onlyDirect is true, only the immediate children will be returned. - GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error) + GetStatusChildren(ctx context.Context, statusID string) ([]*gtsmodel.Status, error) // IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) |