summaryrefslogtreecommitdiff
path: root/internal/db/bundb/tag.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/tag.go')
-rw-r--r--internal/db/bundb/tag.go75
1 files changed, 51 insertions, 24 deletions
diff --git a/internal/db/bundb/tag.go b/internal/db/bundb/tag.go
index fac621f0a..66ee8cb3a 100644
--- a/internal/db/bundb/tag.go
+++ b/internal/db/bundb/tag.go
@@ -22,21 +22,21 @@ import (
"strings"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
type tagDB struct {
- conn *DB
+ db *DB
state *state.State
}
-func (m *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) {
- return m.state.Caches.GTS.Tag().Load("ID", func() (*gtsmodel.Tag, error) {
+func (t *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) {
+ return t.state.Caches.GTS.Tag.LoadOne("ID", func() (*gtsmodel.Tag, error) {
var tag gtsmodel.Tag
- q := m.conn.
+ q := t.db.
NewSelect().
Model(&tag).
Where("? = ?", bun.Ident("tag.id"), id)
@@ -49,15 +49,15 @@ func (m *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) {
}, id)
}
-func (m *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, error) {
+func (t *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, error) {
// Normalize 'name' string.
name = strings.TrimSpace(name)
name = strings.ToLower(name)
- return m.state.Caches.GTS.Tag().Load("Name", func() (*gtsmodel.Tag, error) {
+ return t.state.Caches.GTS.Tag.LoadOne("Name", func() (*gtsmodel.Tag, error) {
var tag gtsmodel.Tag
- q := m.conn.
+ q := t.db.
NewSelect().
Model(&tag).
Where("? = ?", bun.Ident("tag.name"), name)
@@ -70,25 +70,52 @@ func (m *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, e
}, name)
}
-func (m *tagDB) GetTags(ctx context.Context, ids []string) ([]*gtsmodel.Tag, error) {
- tags := make([]*gtsmodel.Tag, 0, len(ids))
-
- for _, id := range ids {
- // Attempt fetch from DB
- tag, err := m.GetTag(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting tag %q: %v", id, err)
- continue
- }
-
- // Append tag
- tags = append(tags, tag)
+func (t *tagDB) GetTags(ctx context.Context, ids []string) ([]*gtsmodel.Tag, error) {
+ // Preallocate at-worst possible length.
+ uncached := make([]string, 0, len(ids))
+
+ // Load all tag IDs via cache loader callbacks.
+ tags, err := t.state.Caches.GTS.Tag.Load("ID",
+
+ // Load cached + check for uncached.
+ func(load func(keyParts ...any) bool) {
+ for _, id := range ids {
+ if !load(id) {
+ uncached = append(uncached, id)
+ }
+ }
+ },
+
+ // Uncached tag loader function.
+ func() ([]*gtsmodel.Tag, error) {
+ // Preallocate expected length of uncached tags.
+ tags := make([]*gtsmodel.Tag, 0, len(uncached))
+
+ // Perform database query scanning
+ // the remaining (uncached) IDs.
+ if err := t.db.NewSelect().
+ Model(&tags).
+ Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return tags, nil
+ },
+ )
+ if err != nil {
+ return nil, err
}
+ // Reorder the tags by their
+ // IDs to ensure in correct order.
+ getID := func(t *gtsmodel.Tag) string { return t.ID }
+ util.OrderBy(tags, ids, getID)
+
return tags, nil
}
-func (m *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error {
+func (t *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error {
// Normalize 'name' string before it enters
// the db, without changing tag we were given.
//
@@ -101,8 +128,8 @@ func (m *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error {
t2.Name = strings.ToLower(t2.Name)
// Insert the copy.
- if err := m.state.Caches.GTS.Tag().Store(t2, func() error {
- _, err := m.conn.NewInsert().Model(t2).Exec(ctx)
+ if err := t.state.Caches.GTS.Tag.Store(t2, func() error {
+ _, err := t.db.NewInsert().Model(t2).Exec(ctx)
return err
}); err != nil {
return err // err already processed