diff options
Diffstat (limited to 'internal/db/bundb/emoji.go')
-rw-r--r-- | internal/db/bundb/emoji.go | 119 |
1 files changed, 114 insertions, 5 deletions
diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go index 55bc71e1e..758da0feb 100644 --- a/internal/db/bundb/emoji.go +++ b/internal/db/bundb/emoji.go @@ -20,27 +20,136 @@ package bundb import ( "context" + "strings" + "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/uptrace/bun" ) type emojiDB struct { - conn *DBConn + conn *DBConn + cache *cache.EmojiCache } -func (e emojiDB) GetCustomEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Error) { - emojis := []*gtsmodel.Emoji{} +func (e *emojiDB) newEmojiQ(emoji *gtsmodel.Emoji) *bun.SelectQuery { + return e.conn. + NewSelect(). + Model(emoji) +} + +func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) db.Error { + if _, err := e.conn.NewInsert().Model(emoji).Exec(ctx); err != nil { + return e.conn.ProcessError(err) + } + + e.cache.Put(emoji) + return nil +} + +func (e *emojiDB) GetCustomEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Error) { + emojiIDs := []string{} q := e.conn. NewSelect(). - Model(&emojis). + Table("emojis"). + Column("id"). Where("visible_in_picker = true"). Where("disabled = false"). + Where("domain IS NULL"). Order("shortcode ASC") - if err := q.Scan(ctx); err != nil { + if err := q.Scan(ctx, &emojiIDs); err != nil { return nil, e.conn.ProcessError(err) } + + return e.emojisFromIDs(ctx, emojiIDs) +} + +func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, db.Error) { + return e.getEmoji( + ctx, + func() (*gtsmodel.Emoji, bool) { + return e.cache.GetByID(id) + }, + func(emoji *gtsmodel.Emoji) error { + return e.newEmojiQ(emoji).Where("emoji.id = ?", id).Scan(ctx) + }, + ) +} + +func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, db.Error) { + return e.getEmoji( + ctx, + func() (*gtsmodel.Emoji, bool) { + return e.cache.GetByURI(uri) + }, + func(emoji *gtsmodel.Emoji) error { + return e.newEmojiQ(emoji).Where("emoji.uri = ?", uri).Scan(ctx) + }, + ) +} + +func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, db.Error) { + return e.getEmoji( + ctx, + func() (*gtsmodel.Emoji, bool) { + return e.cache.GetByShortcodeDomain(shortcode, domain) + }, + func(emoji *gtsmodel.Emoji) error { + q := e.newEmojiQ(emoji) + + if domain != "" { + q = q.Where("emoji.shortcode = ?", shortcode) + q = q.Where("emoji.domain = ?", domain) + } else { + q = q.Where("emoji.shortcode = ?", strings.ToLower(shortcode)) + q = q.Where("emoji.domain IS NULL") + } + + return q.Scan(ctx) + }, + ) +} + +func (e *emojiDB) getEmoji(ctx context.Context, cacheGet func() (*gtsmodel.Emoji, bool), dbQuery func(*gtsmodel.Emoji) error) (*gtsmodel.Emoji, db.Error) { + // Attempt to fetch cached emoji + emoji, cached := cacheGet() + + if !cached { + emoji = >smodel.Emoji{} + + // Not cached! Perform database query + err := dbQuery(emoji) + if err != nil { + return nil, e.conn.ProcessError(err) + } + + // Place in the cache + e.cache.Put(emoji) + } + + return emoji, nil +} + +func (e *emojiDB) emojisFromIDs(ctx context.Context, emojiIDs []string) ([]*gtsmodel.Emoji, db.Error) { + // Catch case of no emojis early + if len(emojiIDs) == 0 { + return nil, db.ErrNoEntries + } + + emojis := make([]*gtsmodel.Emoji, 0, len(emojiIDs)) + + for _, id := range emojiIDs { + emoji, err := e.GetEmojiByID(ctx, id) + if err != nil { + log.Errorf("emojisFromIDs: error getting emoji %q: %v", id, err) + } + + emojis = append(emojis, emoji) + } + return emojis, nil } |