summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/bundb/emoji.go120
-rw-r--r--internal/db/bundb/emoji_test.go92
-rw-r--r--internal/db/emoji.go10
3 files changed, 217 insertions, 5 deletions
diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go
index e781e2f00..640e354c4 100644
--- a/internal/db/bundb/emoji.go
+++ b/internal/db/bundb/emoji.go
@@ -27,6 +27,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/uptrace/bun"
+ "github.com/uptrace/bun/dialect"
)
type emojiDB struct {
@@ -49,7 +50,124 @@ func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) db.Error
return nil
}
-func (e *emojiDB) GetCustomEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Error) {
+func (e *emojiDB) GetEmojis(ctx context.Context, domain string, includeDisabled bool, includeEnabled bool, shortcode string, maxShortcodeDomain string, minShortcodeDomain string, limit int) ([]*gtsmodel.Emoji, db.Error) {
+ emojiIDs := []string{}
+
+ subQuery := e.conn.
+ NewSelect().
+ ColumnExpr("? AS ?", bun.Ident("emoji.id"), bun.Ident("emoji_ids"))
+
+ // To ensure consistent ordering and make paging possible, we sort not by shortcode
+ // but by [shortcode]@[domain]. Because sqlite and postgres have different syntax
+ // for concatenation, that means we need to switch here. Depending on which driver
+ // is in use, query will look something like this (sqlite):
+ //
+ // SELECT
+ // "emoji"."id" AS "emoji_ids",
+ // lower("emoji"."shortcode" || '@' || COALESCE("emoji"."domain", '')) AS "shortcode_domain"
+ // FROM
+ // "emojis" AS "emoji"
+ // ORDER BY
+ // "shortcode_domain" ASC
+ //
+ // Or like this (postgres):
+ //
+ // SELECT
+ // "emoji"."id" AS "emoji_ids",
+ // LOWER(CONCAT("emoji"."shortcode", '@', COALESCE("emoji"."domain", ''))) AS "shortcode_domain"
+ // FROM
+ // "emojis" AS "emoji"
+ // ORDER BY
+ // "shortcode_domain" ASC
+ switch e.conn.Dialect().Name() {
+ case dialect.SQLite:
+ subQuery = subQuery.ColumnExpr("LOWER(? || ? || COALESCE(?, ?)) AS ?", bun.Ident("emoji.shortcode"), "@", bun.Ident("emoji.domain"), "", bun.Ident("shortcode_domain"))
+ case dialect.PG:
+ subQuery = subQuery.ColumnExpr("LOWER(CONCAT(?, ?, COALESCE(?, ?))) AS ?", bun.Ident("emoji.shortcode"), "@", bun.Ident("emoji.domain"), "", bun.Ident("shortcode_domain"))
+ default:
+ panic("db conn was neither pg not sqlite")
+ }
+
+ subQuery = subQuery.TableExpr("? AS ?", bun.Ident("emojis"), bun.Ident("emoji"))
+
+ if domain == "" {
+ subQuery = subQuery.Where("? IS NULL", bun.Ident("emoji.domain"))
+ } else if domain != db.EmojiAllDomains {
+ subQuery = subQuery.Where("? = ?", bun.Ident("emoji.domain"), domain)
+ }
+
+ switch {
+ case includeDisabled && !includeEnabled:
+ // show only disabled emojis
+ subQuery = subQuery.Where("? = ?", bun.Ident("emoji.disabled"), true)
+ case includeEnabled && !includeDisabled:
+ // show only enabled emojis
+ subQuery = subQuery.Where("? = ?", bun.Ident("emoji.disabled"), false)
+ default:
+ // show emojis regardless of emoji.disabled value
+ }
+
+ if shortcode != "" {
+ subQuery = subQuery.Where("LOWER(?) = LOWER(?)", bun.Ident("emoji.shortcode"), shortcode)
+ }
+
+ // assume we want to sort ASC (a-z) unless informed otherwise
+ order := "ASC"
+
+ if maxShortcodeDomain != "" {
+ subQuery = subQuery.Where("? > LOWER(?)", bun.Ident("shortcode_domain"), maxShortcodeDomain)
+ }
+
+ if minShortcodeDomain != "" {
+ subQuery = subQuery.Where("? < LOWER(?)", bun.Ident("shortcode_domain"), minShortcodeDomain)
+ // if we have a minShortcodeDomain we're paging upwards/backwards
+ order = "DESC"
+ }
+
+ subQuery = subQuery.Order("shortcode_domain " + order)
+
+ if limit > 0 {
+ subQuery = subQuery.Limit(limit)
+ }
+
+ // Wrap the subQuery in a query, since we don't need to select the shortcode_domain column.
+ //
+ // The final query will come out looking something like...
+ //
+ // SELECT
+ // "subquery"."emoji_ids"
+ // FROM (
+ // SELECT
+ // "emoji"."id" AS "emoji_ids",
+ // LOWER("emoji"."shortcode" || '@' || COALESCE("emoji"."domain", '')) AS "shortcode_domain"
+ // FROM
+ // "emojis" AS "emoji"
+ // ORDER BY
+ // "shortcode_domain" ASC
+ // ) AS "subquery"
+ if err := e.conn.
+ NewSelect().
+ Column("subquery.emoji_ids").
+ TableExpr("(?) AS ?", subQuery, bun.Ident("subquery")).
+ Scan(ctx, &emojiIDs); err != nil {
+ return nil, e.conn.ProcessError(err)
+ }
+
+ if order == "DESC" {
+ // Reverse the slice order so the caller still
+ // gets emojis in expected a-z alphabetical order.
+ //
+ // See https://github.com/golang/go/wiki/SliceTricks#reversing
+ for i := len(emojiIDs)/2 - 1; i >= 0; i-- {
+ opp := len(emojiIDs) - 1 - i
+ emojiIDs[i], emojiIDs[opp] = emojiIDs[opp], emojiIDs[i]
+ }
+ }
+
+ return e.emojisFromIDs(ctx, emojiIDs)
+}
+
+func (e *emojiDB) GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Error) {
emojiIDs := []string{}
q := e.conn.
diff --git a/internal/db/bundb/emoji_test.go b/internal/db/bundb/emoji_test.go
index 0a1546d91..3c61fb620 100644
--- a/internal/db/bundb/emoji_test.go
+++ b/internal/db/bundb/emoji_test.go
@@ -23,20 +23,108 @@ import (
"testing"
"github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
)
type EmojiTestSuite struct {
BunDBStandardTestSuite
}
-func (suite *EmojiTestSuite) TestGetCustomEmojis() {
- emojis, err := suite.db.GetCustomEmojis(context.Background())
+func (suite *EmojiTestSuite) TestGetUseableEmojis() {
+ emojis, err := suite.db.GetUseableEmojis(context.Background())
suite.NoError(err)
suite.Equal(1, len(emojis))
suite.Equal("rainbow", emojis[0].Shortcode)
}
+func (suite *EmojiTestSuite) TestGetAllEmojis() {
+ emojis, err := suite.db.GetEmojis(context.Background(), db.EmojiAllDomains, true, true, "", "", "", 0)
+
+ suite.NoError(err)
+ suite.Equal(2, len(emojis))
+ suite.Equal("rainbow", emojis[0].Shortcode)
+ suite.Equal("yell", emojis[1].Shortcode)
+}
+
+func (suite *EmojiTestSuite) TestGetAllEmojisLimit1() {
+ emojis, err := suite.db.GetEmojis(context.Background(), db.EmojiAllDomains, true, true, "", "", "", 1)
+
+ suite.NoError(err)
+ suite.Equal(1, len(emojis))
+ suite.Equal("rainbow", emojis[0].Shortcode)
+}
+
+func (suite *EmojiTestSuite) TestGetAllEmojisMaxID() {
+ emojis, err := suite.db.GetEmojis(context.Background(), db.EmojiAllDomains, true, true, "", "rainbow@", "", 0)
+
+ suite.NoError(err)
+ suite.Equal(1, len(emojis))
+ suite.Equal("yell", emojis[0].Shortcode)
+}
+
+func (suite *EmojiTestSuite) TestGetAllEmojisMinID() {
+ emojis, err := suite.db.GetEmojis(context.Background(), db.EmojiAllDomains, true, true, "", "", "yell@fossbros-anonymous.io", 0)
+
+ suite.NoError(err)
+ suite.Equal(1, len(emojis))
+ suite.Equal("rainbow", emojis[0].Shortcode)
+}
+
+func (suite *EmojiTestSuite) TestGetAllDisabledEmojis() {
+ emojis, err := suite.db.GetEmojis(context.Background(), db.EmojiAllDomains, true, false, "", "", "", 0)
+
+ suite.ErrorIs(err, db.ErrNoEntries)
+ suite.Equal(0, len(emojis))
+}
+
+func (suite *EmojiTestSuite) TestGetAllEnabledEmojis() {
+ emojis, err := suite.db.GetEmojis(context.Background(), db.EmojiAllDomains, false, true, "", "", "", 0)
+
+ suite.NoError(err)
+ suite.Equal(2, len(emojis))
+ suite.Equal("rainbow", emojis[0].Shortcode)
+ suite.Equal("yell", emojis[1].Shortcode)
+}
+
+func (suite *EmojiTestSuite) TestGetLocalEnabledEmojis() {
+ emojis, err := suite.db.GetEmojis(context.Background(), "", false, true, "", "", "", 0)
+
+ suite.NoError(err)
+ suite.Equal(1, len(emojis))
+ suite.Equal("rainbow", emojis[0].Shortcode)
+}
+
+func (suite *EmojiTestSuite) TestGetLocalDisabledEmojis() {
+ emojis, err := suite.db.GetEmojis(context.Background(), "", true, false, "", "", "", 0)
+
+ suite.ErrorIs(err, db.ErrNoEntries)
+ suite.Equal(0, len(emojis))
+}
+
+func (suite *EmojiTestSuite) TestGetAllEmojisFromDomain() {
+ emojis, err := suite.db.GetEmojis(context.Background(), "peepee.poopoo", true, true, "", "", "", 0)
+
+ suite.ErrorIs(err, db.ErrNoEntries)
+ suite.Equal(0, len(emojis))
+}
+
+func (suite *EmojiTestSuite) TestGetAllEmojisFromDomain2() {
+ emojis, err := suite.db.GetEmojis(context.Background(), "fossbros-anonymous.io", true, true, "", "", "", 0)
+
+ suite.NoError(err)
+ suite.Equal(1, len(emojis))
+ suite.Equal("yell", emojis[0].Shortcode)
+}
+
+func (suite *EmojiTestSuite) TestGetSpecificEmojisFromDomain2() {
+ emojis, err := suite.db.GetEmojis(context.Background(), "fossbros-anonymous.io", true, true, "yell", "", "", 0)
+
+ suite.NoError(err)
+ suite.Equal(1, len(emojis))
+ suite.Equal("yell", emojis[0].Shortcode)
+}
+
func TestEmojiTestSuite(t *testing.T) {
suite.Run(t, new(EmojiTestSuite))
}
diff --git a/internal/db/emoji.go b/internal/db/emoji.go
index 374fd7b12..4316a43ef 100644
--- a/internal/db/emoji.go
+++ b/internal/db/emoji.go
@@ -24,12 +24,18 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
+// EmojiAllDomains can be used as the `domain` value in a GetEmojis
+// query to indicate that emojis from all domains should be returned.
+const EmojiAllDomains string = "all"
+
// Emoji contains functions for getting emoji in the database.
type Emoji interface {
// PutEmoji puts one emoji in the database.
PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) Error
- // GetCustomEmojis gets all custom emoji for the instance
- GetCustomEmojis(ctx context.Context) ([]*gtsmodel.Emoji, Error)
+ // GetUseableEmojis gets all emojis which are useable by accounts on this instance.
+ GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, Error)
+ // GetEmojis gets emojis based on given parameters. Useful for admin actions.
+ GetEmojis(ctx context.Context, domain string, includeDisabled bool, includeEnabled bool, shortcode string, maxShortcodeDomain string, minShortcodeDomain string, limit int) ([]*gtsmodel.Emoji, Error)
// GetEmojiByID gets a specific emoji by its database ID.
GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, Error)
// GetEmojiByShortcodeDomain gets an emoji based on its shortcode and domain.