diff options
Diffstat (limited to 'internal/db/bundb/tag.go')
-rw-r--r-- | internal/db/bundb/tag.go | 75 |
1 files changed, 51 insertions, 24 deletions
diff --git a/internal/db/bundb/tag.go b/internal/db/bundb/tag.go index fac621f0a..66ee8cb3a 100644 --- a/internal/db/bundb/tag.go +++ b/internal/db/bundb/tag.go @@ -22,21 +22,21 @@ import ( "strings" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" ) type tagDB struct { - conn *DB + db *DB state *state.State } -func (m *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) { - return m.state.Caches.GTS.Tag().Load("ID", func() (*gtsmodel.Tag, error) { +func (t *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) { + return t.state.Caches.GTS.Tag.LoadOne("ID", func() (*gtsmodel.Tag, error) { var tag gtsmodel.Tag - q := m.conn. + q := t.db. NewSelect(). Model(&tag). Where("? = ?", bun.Ident("tag.id"), id) @@ -49,15 +49,15 @@ func (m *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) { }, id) } -func (m *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, error) { +func (t *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, error) { // Normalize 'name' string. name = strings.TrimSpace(name) name = strings.ToLower(name) - return m.state.Caches.GTS.Tag().Load("Name", func() (*gtsmodel.Tag, error) { + return t.state.Caches.GTS.Tag.LoadOne("Name", func() (*gtsmodel.Tag, error) { var tag gtsmodel.Tag - q := m.conn. + q := t.db. NewSelect(). Model(&tag). Where("? = ?", bun.Ident("tag.name"), name) @@ -70,25 +70,52 @@ func (m *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, e }, name) } -func (m *tagDB) GetTags(ctx context.Context, ids []string) ([]*gtsmodel.Tag, error) { - tags := make([]*gtsmodel.Tag, 0, len(ids)) - - for _, id := range ids { - // Attempt fetch from DB - tag, err := m.GetTag(ctx, id) - if err != nil { - log.Errorf(ctx, "error getting tag %q: %v", id, err) - continue - } - - // Append tag - tags = append(tags, tag) +func (t *tagDB) GetTags(ctx context.Context, ids []string) ([]*gtsmodel.Tag, error) { + // Preallocate at-worst possible length. + uncached := make([]string, 0, len(ids)) + + // Load all tag IDs via cache loader callbacks. + tags, err := t.state.Caches.GTS.Tag.Load("ID", + + // Load cached + check for uncached. + func(load func(keyParts ...any) bool) { + for _, id := range ids { + if !load(id) { + uncached = append(uncached, id) + } + } + }, + + // Uncached tag loader function. + func() ([]*gtsmodel.Tag, error) { + // Preallocate expected length of uncached tags. + tags := make([]*gtsmodel.Tag, 0, len(uncached)) + + // Perform database query scanning + // the remaining (uncached) IDs. + if err := t.db.NewSelect(). + Model(&tags). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); err != nil { + return nil, err + } + + return tags, nil + }, + ) + if err != nil { + return nil, err } + // Reorder the tags by their + // IDs to ensure in correct order. + getID := func(t *gtsmodel.Tag) string { return t.ID } + util.OrderBy(tags, ids, getID) + return tags, nil } -func (m *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error { +func (t *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error { // Normalize 'name' string before it enters // the db, without changing tag we were given. // @@ -101,8 +128,8 @@ func (m *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error { t2.Name = strings.ToLower(t2.Name) // Insert the copy. - if err := m.state.Caches.GTS.Tag().Store(t2, func() error { - _, err := m.conn.NewInsert().Model(t2).Exec(ctx) + if err := t.state.Caches.GTS.Tag.Store(t2, func() error { + _, err := t.db.NewInsert().Model(t2).Exec(ctx) return err }); err != nil { return err // err already processed |