diff options
Diffstat (limited to 'internal/db/bundb/tag.go')
-rw-r--r-- | internal/db/bundb/tag.go | 159 |
1 files changed, 159 insertions, 0 deletions
diff --git a/internal/db/bundb/tag.go b/internal/db/bundb/tag.go index 5218a19d5..c6298ee64 100644 --- a/internal/db/bundb/tag.go +++ b/internal/db/bundb/tag.go @@ -19,9 +19,13 @@ package bundb import ( "context" + "errors" "strings" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" @@ -131,3 +135,158 @@ func (t *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error { return nil } + +func (t *tagDB) GetFollowedTags(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Tag, error) { + tagIDs, err := t.getTagIDsFollowedByAccount(ctx, accountID, page) + if err != nil { + return nil, err + } + + tags, err := t.GetTags(ctx, tagIDs) + if err != nil { + return nil, err + } + + return tags, nil +} + +func (t *tagDB) getTagIDsFollowedByAccount(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { + return loadPagedIDs(&t.state.Caches.DB.TagIDsFollowedByAccount, accountID, page, func() ([]string, error) { + var tagIDs []string + + // Tag IDs not in cache. Perform DB query. + if _, err := t.db. + NewSelect(). + Model((*gtsmodel.FollowedTag)(nil)). + Column("tag_id"). + Where("? = ?", bun.Ident("account_id"), accountID). + OrderExpr("? DESC", bun.Ident("tag_id")). + Exec(ctx, &tagIDs); // nocollapse + err != nil && !errors.Is(err, db.ErrNoEntries) { + return nil, gtserror.Newf("error getting tag IDs followed by account %s: %w", accountID, err) + } + + return tagIDs, nil + }) +} + +func (t *tagDB) getAccountIDsFollowingTag(ctx context.Context, tagID string) ([]string, error) { + return loadPagedIDs(&t.state.Caches.DB.AccountIDsFollowingTag, tagID, nil, func() ([]string, error) { + var accountIDs []string + + // Account IDs not in cache. Perform DB query. + if _, err := t.db. + NewSelect(). + Model((*gtsmodel.FollowedTag)(nil)). + Column("account_id"). + Where("? = ?", bun.Ident("tag_id"), tagID). + OrderExpr("? DESC", bun.Ident("account_id")). + Exec(ctx, &accountIDs); // nocollapse + err != nil && !errors.Is(err, db.ErrNoEntries) { + return nil, gtserror.Newf("error getting account IDs following tag %s: %w", tagID, err) + } + + return accountIDs, nil + }) +} + +func (t *tagDB) IsAccountFollowingTag(ctx context.Context, accountID string, tagID string) (bool, error) { + accountTagIDs, err := t.getTagIDsFollowedByAccount(ctx, accountID, nil) + if err != nil { + return false, err + } + + for _, accountTagID := range accountTagIDs { + if accountTagID == tagID { + return true, nil + } + } + + return false, nil +} + +func (t *tagDB) PutFollowedTag(ctx context.Context, accountID string, tagID string) error { + // Insert the followed tag. + result, err := t.db.NewInsert(). + Model(>smodel.FollowedTag{ + AccountID: accountID, + TagID: tagID, + }). + On("CONFLICT (?, ?) DO NOTHING", bun.Ident("account_id"), bun.Ident("tag_id")). + Exec(ctx) + if err != nil { + return gtserror.Newf("error inserting followed tag: %w", err) + } + + // If it fails because that account already follows that tag, that's fine, and we're done. + rows, err := result.RowsAffected() + if err != nil { + return gtserror.Newf("error getting inserted row count: %w", err) + } + if rows == 0 { + return nil + } + + // Otherwise, this is a new followed tag, so we invalidate caches related to it. + t.state.Caches.DB.AccountIDsFollowingTag.Invalidate(tagID) + t.state.Caches.DB.TagIDsFollowedByAccount.Invalidate(accountID) + + return nil +} + +func (t *tagDB) DeleteFollowedTag(ctx context.Context, accountID string, tagID string) error { + result, err := t.db.NewDelete(). + Model((*gtsmodel.FollowedTag)(nil)). + Where("? = ?", bun.Ident("account_id"), accountID). + Where("? = ?", bun.Ident("tag_id"), tagID). + Exec(ctx) + if err != nil { + return gtserror.Newf("error deleting followed tag %s for account %s: %w", tagID, accountID, err) + } + + rows, err := result.RowsAffected() + if err != nil { + return gtserror.Newf("error getting inserted row count: %w", err) + } + if rows == 0 { + return nil + } + + // If we deleted anything, invalidate caches related to it. + t.state.Caches.DB.AccountIDsFollowingTag.Invalidate(tagID) + t.state.Caches.DB.TagIDsFollowedByAccount.Invalidate(accountID) + + return err +} + +func (t *tagDB) DeleteFollowedTagsByAccountID(ctx context.Context, accountID string) error { + // Delete followed tags from the database, returning the list of tag IDs affected. + tagIDs := []string{} + if err := t.db.NewDelete(). + Model((*gtsmodel.FollowedTag)(nil)). + Where("? = ?", bun.Ident("account_id"), accountID). + Returning("?", bun.Ident("tag_id")). + Scan(ctx, &tagIDs); // nocollapse + err != nil { + return gtserror.Newf("error deleting followed tags for account %s: %w", accountID, err) + } + + // Invalidate account ID caches for the account and those tags. + t.state.Caches.DB.TagIDsFollowedByAccount.Invalidate(accountID) + t.state.Caches.DB.AccountIDsFollowingTag.Invalidate(tagIDs...) + + return nil +} + +func (t *tagDB) GetAccountIDsFollowingTagIDs(ctx context.Context, tagIDs []string) ([]string, error) { + // Accounts might be following multiple tags in this list, but we only want to return each account once. + accountIDs := []string{} + for _, tagID := range tagIDs { + tagAccountIDs, err := t.getAccountIDsFollowingTag(ctx, tagID) + if err != nil { + return nil, err + } + accountIDs = append(accountIDs, tagAccountIDs...) + } + return util.UniqueStrings(accountIDs), nil +} |