diff options
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/bundb/emoji.go | 120 | ||||
-rw-r--r-- | internal/db/bundb/emoji_test.go | 92 | ||||
-rw-r--r-- | internal/db/emoji.go | 10 |
3 files changed, 217 insertions, 5 deletions
diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go index e781e2f00..640e354c4 100644 --- a/internal/db/bundb/emoji.go +++ b/internal/db/bundb/emoji.go @@ -27,6 +27,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect" ) type emojiDB struct { @@ -49,7 +50,124 @@ func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) db.Error return nil } -func (e *emojiDB) GetCustomEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Error) { +func (e *emojiDB) GetEmojis(ctx context.Context, domain string, includeDisabled bool, includeEnabled bool, shortcode string, maxShortcodeDomain string, minShortcodeDomain string, limit int) ([]*gtsmodel.Emoji, db.Error) { + emojiIDs := []string{} + + subQuery := e.conn. + NewSelect(). + ColumnExpr("? AS ?", bun.Ident("emoji.id"), bun.Ident("emoji_ids")) + + // To ensure consistent ordering and make paging possible, we sort not by shortcode + // but by [shortcode]@[domain]. Because sqlite and postgres have different syntax + // for concatenation, that means we need to switch here. Depending on which driver + // is in use, query will look something like this (sqlite): + // + // SELECT + // "emoji"."id" AS "emoji_ids", + // lower("emoji"."shortcode" || '@' || COALESCE("emoji"."domain", '')) AS "shortcode_domain" + // FROM + // "emojis" AS "emoji" + // ORDER BY + // "shortcode_domain" ASC + // + // Or like this (postgres): + // + // SELECT + // "emoji"."id" AS "emoji_ids", + // LOWER(CONCAT("emoji"."shortcode", '@', COALESCE("emoji"."domain", ''))) AS "shortcode_domain" + // FROM + // "emojis" AS "emoji" + // ORDER BY + // "shortcode_domain" ASC + switch e.conn.Dialect().Name() { + case dialect.SQLite: + subQuery = subQuery.ColumnExpr("LOWER(? || ? || COALESCE(?, ?)) AS ?", bun.Ident("emoji.shortcode"), "@", bun.Ident("emoji.domain"), "", bun.Ident("shortcode_domain")) + case dialect.PG: + subQuery = subQuery.ColumnExpr("LOWER(CONCAT(?, ?, COALESCE(?, ?))) AS ?", bun.Ident("emoji.shortcode"), "@", bun.Ident("emoji.domain"), "", bun.Ident("shortcode_domain")) + default: + panic("db conn was neither pg not sqlite") + } + + subQuery = subQuery.TableExpr("? AS ?", bun.Ident("emojis"), bun.Ident("emoji")) + + if domain == "" { + subQuery = subQuery.Where("? IS NULL", bun.Ident("emoji.domain")) + } else if domain != db.EmojiAllDomains { + subQuery = subQuery.Where("? = ?", bun.Ident("emoji.domain"), domain) + } + + switch { + case includeDisabled && !includeEnabled: + // show only disabled emojis + subQuery = subQuery.Where("? = ?", bun.Ident("emoji.disabled"), true) + case includeEnabled && !includeDisabled: + // show only enabled emojis + subQuery = subQuery.Where("? = ?", bun.Ident("emoji.disabled"), false) + default: + // show emojis regardless of emoji.disabled value + } + + if shortcode != "" { + subQuery = subQuery.Where("LOWER(?) = LOWER(?)", bun.Ident("emoji.shortcode"), shortcode) + } + + // assume we want to sort ASC (a-z) unless informed otherwise + order := "ASC" + + if maxShortcodeDomain != "" { + subQuery = subQuery.Where("? > LOWER(?)", bun.Ident("shortcode_domain"), maxShortcodeDomain) + } + + if minShortcodeDomain != "" { + subQuery = subQuery.Where("? < LOWER(?)", bun.Ident("shortcode_domain"), minShortcodeDomain) + // if we have a minShortcodeDomain we're paging upwards/backwards + order = "DESC" + } + + subQuery = subQuery.Order("shortcode_domain " + order) + + if limit > 0 { + subQuery = subQuery.Limit(limit) + } + + // Wrap the subQuery in a query, since we don't need to select the shortcode_domain column. + // + // The final query will come out looking something like... + // + // SELECT + // "subquery"."emoji_ids" + // FROM ( + // SELECT + // "emoji"."id" AS "emoji_ids", + // LOWER("emoji"."shortcode" || '@' || COALESCE("emoji"."domain", '')) AS "shortcode_domain" + // FROM + // "emojis" AS "emoji" + // ORDER BY + // "shortcode_domain" ASC + // ) AS "subquery" + if err := e.conn. + NewSelect(). + Column("subquery.emoji_ids"). + TableExpr("(?) AS ?", subQuery, bun.Ident("subquery")). + Scan(ctx, &emojiIDs); err != nil { + return nil, e.conn.ProcessError(err) + } + + if order == "DESC" { + // Reverse the slice order so the caller still + // gets emojis in expected a-z alphabetical order. + // + // See https://github.com/golang/go/wiki/SliceTricks#reversing + for i := len(emojiIDs)/2 - 1; i >= 0; i-- { + opp := len(emojiIDs) - 1 - i + emojiIDs[i], emojiIDs[opp] = emojiIDs[opp], emojiIDs[i] + } + } + + return e.emojisFromIDs(ctx, emojiIDs) +} + +func (e *emojiDB) GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Error) { emojiIDs := []string{} q := e.conn. diff --git a/internal/db/bundb/emoji_test.go b/internal/db/bundb/emoji_test.go index 0a1546d91..3c61fb620 100644 --- a/internal/db/bundb/emoji_test.go +++ b/internal/db/bundb/emoji_test.go @@ -23,20 +23,108 @@ import ( "testing" "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/db" ) type EmojiTestSuite struct { BunDBStandardTestSuite } -func (suite *EmojiTestSuite) TestGetCustomEmojis() { - emojis, err := suite.db.GetCustomEmojis(context.Background()) +func (suite *EmojiTestSuite) TestGetUseableEmojis() { + emojis, err := suite.db.GetUseableEmojis(context.Background()) suite.NoError(err) suite.Equal(1, len(emojis)) suite.Equal("rainbow", emojis[0].Shortcode) } +func (suite *EmojiTestSuite) TestGetAllEmojis() { + emojis, err := suite.db.GetEmojis(context.Background(), db.EmojiAllDomains, true, true, "", "", "", 0) + + suite.NoError(err) + suite.Equal(2, len(emojis)) + suite.Equal("rainbow", emojis[0].Shortcode) + suite.Equal("yell", emojis[1].Shortcode) +} + +func (suite *EmojiTestSuite) TestGetAllEmojisLimit1() { + emojis, err := suite.db.GetEmojis(context.Background(), db.EmojiAllDomains, true, true, "", "", "", 1) + + suite.NoError(err) + suite.Equal(1, len(emojis)) + suite.Equal("rainbow", emojis[0].Shortcode) +} + +func (suite *EmojiTestSuite) TestGetAllEmojisMaxID() { + emojis, err := suite.db.GetEmojis(context.Background(), db.EmojiAllDomains, true, true, "", "rainbow@", "", 0) + + suite.NoError(err) + suite.Equal(1, len(emojis)) + suite.Equal("yell", emojis[0].Shortcode) +} + +func (suite *EmojiTestSuite) TestGetAllEmojisMinID() { + emojis, err := suite.db.GetEmojis(context.Background(), db.EmojiAllDomains, true, true, "", "", "yell@fossbros-anonymous.io", 0) + + suite.NoError(err) + suite.Equal(1, len(emojis)) + suite.Equal("rainbow", emojis[0].Shortcode) +} + +func (suite *EmojiTestSuite) TestGetAllDisabledEmojis() { + emojis, err := suite.db.GetEmojis(context.Background(), db.EmojiAllDomains, true, false, "", "", "", 0) + + suite.ErrorIs(err, db.ErrNoEntries) + suite.Equal(0, len(emojis)) +} + +func (suite *EmojiTestSuite) TestGetAllEnabledEmojis() { + emojis, err := suite.db.GetEmojis(context.Background(), db.EmojiAllDomains, false, true, "", "", "", 0) + + suite.NoError(err) + suite.Equal(2, len(emojis)) + suite.Equal("rainbow", emojis[0].Shortcode) + suite.Equal("yell", emojis[1].Shortcode) +} + +func (suite *EmojiTestSuite) TestGetLocalEnabledEmojis() { + emojis, err := suite.db.GetEmojis(context.Background(), "", false, true, "", "", "", 0) + + suite.NoError(err) + suite.Equal(1, len(emojis)) + suite.Equal("rainbow", emojis[0].Shortcode) +} + +func (suite *EmojiTestSuite) TestGetLocalDisabledEmojis() { + emojis, err := suite.db.GetEmojis(context.Background(), "", true, false, "", "", "", 0) + + suite.ErrorIs(err, db.ErrNoEntries) + suite.Equal(0, len(emojis)) +} + +func (suite *EmojiTestSuite) TestGetAllEmojisFromDomain() { + emojis, err := suite.db.GetEmojis(context.Background(), "peepee.poopoo", true, true, "", "", "", 0) + + suite.ErrorIs(err, db.ErrNoEntries) + suite.Equal(0, len(emojis)) +} + +func (suite *EmojiTestSuite) TestGetAllEmojisFromDomain2() { + emojis, err := suite.db.GetEmojis(context.Background(), "fossbros-anonymous.io", true, true, "", "", "", 0) + + suite.NoError(err) + suite.Equal(1, len(emojis)) + suite.Equal("yell", emojis[0].Shortcode) +} + +func (suite *EmojiTestSuite) TestGetSpecificEmojisFromDomain2() { + emojis, err := suite.db.GetEmojis(context.Background(), "fossbros-anonymous.io", true, true, "yell", "", "", 0) + + suite.NoError(err) + suite.Equal(1, len(emojis)) + suite.Equal("yell", emojis[0].Shortcode) +} + func TestEmojiTestSuite(t *testing.T) { suite.Run(t, new(EmojiTestSuite)) } diff --git a/internal/db/emoji.go b/internal/db/emoji.go index 374fd7b12..4316a43ef 100644 --- a/internal/db/emoji.go +++ b/internal/db/emoji.go @@ -24,12 +24,18 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) +// EmojiAllDomains can be used as the `domain` value in a GetEmojis +// query to indicate that emojis from all domains should be returned. +const EmojiAllDomains string = "all" + // Emoji contains functions for getting emoji in the database. type Emoji interface { // PutEmoji puts one emoji in the database. PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) Error - // GetCustomEmojis gets all custom emoji for the instance - GetCustomEmojis(ctx context.Context) ([]*gtsmodel.Emoji, Error) + // GetUseableEmojis gets all emojis which are useable by accounts on this instance. + GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, Error) + // GetEmojis gets emojis based on given parameters. Useful for admin actions. + GetEmojis(ctx context.Context, domain string, includeDisabled bool, includeEnabled bool, shortcode string, maxShortcodeDomain string, minShortcodeDomain string, limit int) ([]*gtsmodel.Emoji, Error) // GetEmojiByID gets a specific emoji by its database ID. GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, Error) // GetEmojiByShortcodeDomain gets an emoji based on its shortcode and domain. |