summaryrefslogtreecommitdiff
path: root/internal/db/bundb/filter.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/filter.go')
-rw-r--r--internal/db/bundb/filter.go263
1 files changed, 67 insertions, 196 deletions
diff --git a/internal/db/bundb/filter.go b/internal/db/bundb/filter.go
index 24208b1f3..dbc560a12 100644
--- a/internal/db/bundb/filter.go
+++ b/internal/db/bundb/filter.go
@@ -21,8 +21,8 @@ import (
"context"
"errors"
"slices"
- "time"
+ "code.superseriousbusiness.org/gotosocial/internal/db"
"code.superseriousbusiness.org/gotosocial/internal/gtscontext"
"code.superseriousbusiness.org/gotosocial/internal/gtserror"
"code.superseriousbusiness.org/gotosocial/internal/gtsmodel"
@@ -64,24 +64,14 @@ func (f *filterDB) GetFilterByID(ctx context.Context, id string) (*gtsmodel.Filt
return filter, nil
}
-func (f *filterDB) GetFiltersForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.Filter, error) {
- // Fetch IDs of all filters owned by this account.
- var filterIDs []string
- if err := f.db.
- NewSelect().
- Model((*gtsmodel.Filter)(nil)).
- Column("id").
- Where("? = ?", bun.Ident("account_id"), accountID).
- Scan(ctx, &filterIDs); err != nil {
- return nil, err
- }
- if len(filterIDs) == 0 {
+func (f *filterDB) GetFiltersByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Filter, error) {
+ if len(ids) == 0 {
return nil, nil
}
// Get each filter by ID from the cache or DB.
filters, err := f.state.Caches.DB.Filter.LoadIDs("ID",
- filterIDs,
+ ids,
func(uncached []string) ([]*gtsmodel.Filter, error) {
filters := make([]*gtsmodel.Filter, 0, len(uncached))
if err := f.db.
@@ -99,14 +89,15 @@ func (f *filterDB) GetFiltersForAccountID(ctx context.Context, accountID string)
}
// Put the filter structs in the same order as the filter IDs.
- xslices.OrderBy(filters, filterIDs, func(filter *gtsmodel.Filter) string { return filter.ID })
+ xslices.OrderBy(filters, ids, func(filter *gtsmodel.Filter) string { return filter.ID })
if gtscontext.Barebones(ctx) {
return filters, nil
}
+ var errs gtserror.MultiError
+
// Populate the filters. Remove any that we can't populate from the return slice.
- errs := gtserror.NewMultiError(len(filters))
filters = slices.DeleteFunc(filters, func(filter *gtsmodel.Filter) bool {
if err := f.populateFilter(ctx, filter); err != nil {
errs.Appendf("error populating filter %s: %w", filter.ID, err)
@@ -118,235 +109,115 @@ func (f *filterDB) GetFiltersForAccountID(ctx context.Context, accountID string)
return filters, errs.Combine()
}
+func (f *filterDB) GetFilterIDsByAccountID(ctx context.Context, accountID string) ([]string, error) {
+ return f.state.Caches.DB.FilterIDs.Load(accountID, func() ([]string, error) {
+ var filterIDs []string
+
+ if err := f.db.
+ NewSelect().
+ Model((*gtsmodel.Filter)(nil)).
+ Column("id").
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ Scan(ctx, &filterIDs); err != nil {
+ return nil, err
+ }
+
+ return filterIDs, nil
+ })
+}
+
+func (f *filterDB) GetFiltersByAccountID(ctx context.Context, accountID string) ([]*gtsmodel.Filter, error) {
+ filterIDs, err := f.GetFilterIDsByAccountID(ctx, accountID)
+ if err != nil {
+ return nil, gtserror.Newf("error getting filter ids: %w", err)
+ }
+ return f.GetFiltersByIDs(ctx, filterIDs)
+}
+
func (f *filterDB) populateFilter(ctx context.Context, filter *gtsmodel.Filter) error {
var err error
- errs := gtserror.NewMultiError(2)
+ var errs gtserror.MultiError
- if filter.Keywords == nil {
+ if !filter.KeywordsPopulated() {
// Filter keywords are not set, fetch from the database.
- filter.Keywords, err = f.state.DB.GetFilterKeywordsForFilterID(
- gtscontext.SetBarebones(ctx),
- filter.ID,
- )
+ filter.Keywords, err = f.GetFilterKeywordsByIDs(ctx, filter.KeywordIDs)
if err != nil {
errs.Appendf("error populating filter keywords: %w", err)
}
- for i := range filter.Keywords {
- filter.Keywords[i].Filter = filter
- }
}
- if filter.Statuses == nil {
+ if !filter.StatusesPopulated() {
// Filter statuses are not set, fetch from the database.
- filter.Statuses, err = f.state.DB.GetFilterStatusesForFilterID(
- gtscontext.SetBarebones(ctx),
- filter.ID,
- )
+ filter.Statuses, err = f.GetFilterStatusesByIDs(ctx, filter.StatusIDs)
if err != nil {
errs.Appendf("error populating filter statuses: %w", err)
}
- for i := range filter.Statuses {
- filter.Statuses[i].Filter = filter
- }
}
return errs.Combine()
}
func (f *filterDB) PutFilter(ctx context.Context, filter *gtsmodel.Filter) error {
- // Pre-compile filter keyword regular expressions.
- for _, filterKeyword := range filter.Keywords {
- if err := filterKeyword.Compile(); err != nil {
- return gtserror.Newf("error compiling filter keyword regex: %w", err)
- }
- }
-
- // Update database.
- if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
- if _, err := tx.
- NewInsert().
- Model(filter).
- Exec(ctx); err != nil {
- return err
- }
-
- if len(filter.Keywords) > 0 {
- if _, err := tx.
- NewInsert().
- Model(&filter.Keywords).
- Exec(ctx); err != nil {
- return err
- }
- }
-
- if len(filter.Statuses) > 0 {
- if _, err := tx.
- NewInsert().
- Model(&filter.Statuses).
- Exec(ctx); err != nil {
- return err
- }
- }
-
- return nil
- }); err != nil {
+ return f.state.Caches.DB.Filter.Store(filter, func() error {
+ _, err := f.db.NewInsert().Model(filter).Exec(ctx)
return err
- }
-
- // Update cache.
- f.state.Caches.DB.Filter.Put(filter)
- f.state.Caches.DB.FilterKeyword.Put(filter.Keywords...)
- f.state.Caches.DB.FilterStatus.Put(filter.Statuses...)
-
- return nil
+ })
}
-func (f *filterDB) UpdateFilter(
- ctx context.Context,
- filter *gtsmodel.Filter,
- filterColumns []string,
- filterKeywordColumns [][]string,
- deleteFilterKeywordIDs []string,
- deleteFilterStatusIDs []string,
-) error {
- if len(filter.Keywords) != len(filterKeywordColumns) {
- return errors.New("number of filter keywords must match number of lists of filter keyword columns")
- }
-
- updatedAt := time.Now()
- filter.UpdatedAt = updatedAt
- for _, filterKeyword := range filter.Keywords {
- filterKeyword.UpdatedAt = updatedAt
- }
- for _, filterStatus := range filter.Statuses {
- filterStatus.UpdatedAt = updatedAt
- }
-
- // If we're updating by column, ensure "updated_at" is included.
- if len(filterColumns) > 0 {
- filterColumns = append(filterColumns, "updated_at")
- }
- for i := range filterKeywordColumns {
- if len(filterKeywordColumns[i]) > 0 {
- filterKeywordColumns[i] = append(filterKeywordColumns[i], "updated_at")
- }
- }
-
- // Pre-compile filter keyword regular expressions.
- for _, filterKeyword := range filter.Keywords {
- if err := filterKeyword.Compile(); err != nil {
- return gtserror.Newf("error compiling filter keyword regex: %w", err)
- }
- }
-
- // Update database.
- if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
- if _, err := tx.
- NewUpdate().
+func (f *filterDB) UpdateFilter(ctx context.Context, filter *gtsmodel.Filter, cols ...string) error {
+ return f.state.Caches.DB.Filter.Store(filter, func() error {
+ _, err := f.db.NewUpdate().
Model(filter).
- Column(filterColumns...).
Where("? = ?", bun.Ident("id"), filter.ID).
- Exec(ctx); err != nil {
- return err
- }
-
- for i, filterKeyword := range filter.Keywords {
- if _, err := NewUpsert(tx).
- Model(filterKeyword).
- Constraint("id").
- Column(filterKeywordColumns[i]...).
- Exec(ctx); err != nil {
- return err
- }
- }
-
- if len(filter.Statuses) > 0 {
- if _, err := tx.
- NewInsert().
- Ignore().
- Model(&filter.Statuses).
- Exec(ctx); err != nil {
- return err
- }
- }
-
- if len(deleteFilterKeywordIDs) > 0 {
- if _, err := tx.
- NewDelete().
- Model((*gtsmodel.FilterKeyword)(nil)).
- Where("? = (?)", bun.Ident("id"), bun.In(deleteFilterKeywordIDs)).
- Exec(ctx); err != nil {
- return err
- }
- }
-
- if len(deleteFilterStatusIDs) > 0 {
- if _, err := tx.
- NewDelete().
- Model((*gtsmodel.FilterStatus)(nil)).
- Where("? = (?)", bun.Ident("id"), bun.In(deleteFilterStatusIDs)).
- Exec(ctx); err != nil {
- return err
- }
- }
-
- return nil
- }); err != nil {
+ Column(cols...).
+ Exec(ctx)
return err
- }
-
- // Update cache.
- f.state.Caches.DB.Filter.Put(filter)
- f.state.Caches.DB.FilterKeyword.Put(filter.Keywords...)
- f.state.Caches.DB.FilterStatus.Put(filter.Statuses...)
- // TODO: (Vyr) replace with cache multi-invalidate call
- for _, id := range deleteFilterKeywordIDs {
- f.state.Caches.DB.FilterKeyword.Invalidate("ID", id)
- }
- for _, id := range deleteFilterStatusIDs {
- f.state.Caches.DB.FilterStatus.Invalidate("ID", id)
- }
-
- return nil
+ })
}
-func (f *filterDB) DeleteFilterByID(ctx context.Context, id string) error {
+func (f *filterDB) DeleteFilter(ctx context.Context, filter *gtsmodel.Filter) error {
if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
- // Delete all keywords attached to filter.
+ // Delete all keywords both known
+ // by filter, and possible stragglers,
+ // storing IDs in filter.KeywordIDs.
if _, err := tx.
NewDelete().
Model((*gtsmodel.FilterKeyword)(nil)).
- Where("? = ?", bun.Ident("filter_id"), id).
- Exec(ctx); err != nil {
+ Where("? = ?", bun.Ident("filter_id"), filter.ID).
+ Returning("?", bun.Ident("id")).
+ Exec(ctx, &filter.KeywordIDs); err != nil &&
+ !errors.Is(err, db.ErrNoEntries) {
return err
}
- // Delete all statuses attached to filter.
+ // Delete all statuses both known
+ // by filter, and possible stragglers.
+ // storing IDs in filter.StatusIDs.
if _, err := tx.
NewDelete().
Model((*gtsmodel.FilterStatus)(nil)).
- Where("? = ?", bun.Ident("filter_id"), id).
- Exec(ctx); err != nil {
+ Where("? = ?", bun.Ident("filter_id"), filter.ID).
+ Returning("?", bun.Ident("id")).
+ Exec(ctx, &filter.StatusIDs); err != nil &&
+ !errors.Is(err, db.ErrNoEntries) {
return err
}
- // Delete the filter itself.
+ // Delete filter itself.
_, err := tx.
NewDelete().
Model((*gtsmodel.Filter)(nil)).
- Where("? = ?", bun.Ident("id"), id).
+ Where("? = ?", bun.Ident("id"), filter.ID).
Exec(ctx)
return err
}); err != nil {
return err
}
- // Invalidate this filter.
- f.state.Caches.DB.Filter.Invalidate("ID", id)
-
- // Invalidate all keywords and statuses for this filter.
- f.state.Caches.DB.FilterKeyword.Invalidate("FilterID", id)
- f.state.Caches.DB.FilterStatus.Invalidate("FilterID", id)
+ // Invalidate the filter itself, and
+ // call invalidate hook in-case not cached.
+ f.state.Caches.DB.Filter.Invalidate("ID", filter.ID)
+ f.state.Caches.OnInvalidateFilter(filter)
return nil
}