diff options
Diffstat (limited to 'internal/db/bundb')
| -rw-r--r-- | internal/db/bundb/emoji.go | 120 | ||||
| -rw-r--r-- | internal/db/bundb/emoji_test.go | 92 | 
2 files changed, 209 insertions, 3 deletions
diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go index e781e2f00..640e354c4 100644 --- a/internal/db/bundb/emoji.go +++ b/internal/db/bundb/emoji.go @@ -27,6 +27,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/uptrace/bun" +	"github.com/uptrace/bun/dialect"  )  type emojiDB struct { @@ -49,7 +50,124 @@ func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) db.Error  	return nil  } -func (e *emojiDB) GetCustomEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Error) { +func (e *emojiDB) GetEmojis(ctx context.Context, domain string, includeDisabled bool, includeEnabled bool, shortcode string, maxShortcodeDomain string, minShortcodeDomain string, limit int) ([]*gtsmodel.Emoji, db.Error) { +	emojiIDs := []string{} + +	subQuery := e.conn. +		NewSelect(). +		ColumnExpr("? AS ?", bun.Ident("emoji.id"), bun.Ident("emoji_ids")) + +	// To ensure consistent ordering and make paging possible, we sort not by shortcode +	// but by [shortcode]@[domain]. Because sqlite and postgres have different syntax +	// for concatenation, that means we need to switch here. Depending on which driver +	// is in use, query will look something like this (sqlite): +	// +	//	SELECT +	//		"emoji"."id" AS "emoji_ids", +	//		lower("emoji"."shortcode" || '@' || COALESCE("emoji"."domain", '')) AS "shortcode_domain" +	//	FROM +	//		"emojis" AS "emoji" +	//	ORDER BY +	//		"shortcode_domain" ASC +	// +	// Or like this (postgres): +	// +	//	SELECT +	//		"emoji"."id" AS "emoji_ids", +	//		LOWER(CONCAT("emoji"."shortcode", '@', COALESCE("emoji"."domain", ''))) AS "shortcode_domain" +	//	FROM +	//		"emojis" AS "emoji" +	//	ORDER BY +	//		"shortcode_domain" ASC +	switch e.conn.Dialect().Name() { +	case dialect.SQLite: +		subQuery = subQuery.ColumnExpr("LOWER(? || ? || COALESCE(?, ?)) AS ?", bun.Ident("emoji.shortcode"), "@", bun.Ident("emoji.domain"), "", bun.Ident("shortcode_domain")) +	case dialect.PG: +		subQuery = subQuery.ColumnExpr("LOWER(CONCAT(?, ?, COALESCE(?, ?))) AS ?", bun.Ident("emoji.shortcode"), "@", bun.Ident("emoji.domain"), "", bun.Ident("shortcode_domain")) +	default: +		panic("db conn was neither pg not sqlite") +	} + +	subQuery = subQuery.TableExpr("? AS ?", bun.Ident("emojis"), bun.Ident("emoji")) + +	if domain == "" { +		subQuery = subQuery.Where("? IS NULL", bun.Ident("emoji.domain")) +	} else if domain != db.EmojiAllDomains { +		subQuery = subQuery.Where("? = ?", bun.Ident("emoji.domain"), domain) +	} + +	switch { +	case includeDisabled && !includeEnabled: +		// show only disabled emojis +		subQuery = subQuery.Where("? = ?", bun.Ident("emoji.disabled"), true) +	case includeEnabled && !includeDisabled: +		// show only enabled emojis +		subQuery = subQuery.Where("? = ?", bun.Ident("emoji.disabled"), false) +	default: +		// show emojis regardless of emoji.disabled value +	} + +	if shortcode != "" { +		subQuery = subQuery.Where("LOWER(?) = LOWER(?)", bun.Ident("emoji.shortcode"), shortcode) +	} + +	// assume we want to sort ASC (a-z) unless informed otherwise +	order := "ASC" + +	if maxShortcodeDomain != "" { +		subQuery = subQuery.Where("? > LOWER(?)", bun.Ident("shortcode_domain"), maxShortcodeDomain) +	} + +	if minShortcodeDomain != "" { +		subQuery = subQuery.Where("? < LOWER(?)", bun.Ident("shortcode_domain"), minShortcodeDomain) +		// if we have a minShortcodeDomain we're paging upwards/backwards +		order = "DESC" +	} + +	subQuery = subQuery.Order("shortcode_domain " + order) + +	if limit > 0 { +		subQuery = subQuery.Limit(limit) +	} + +	// Wrap the subQuery in a query, since we don't need to select the shortcode_domain column. +	// +	// The final query will come out looking something like... +	// +	//	SELECT +	//		"subquery"."emoji_ids" +	//	FROM ( +	//		SELECT +	//			"emoji"."id" AS "emoji_ids", +	//			LOWER("emoji"."shortcode" || '@' || COALESCE("emoji"."domain", '')) AS "shortcode_domain" +	//		FROM +	//			"emojis" AS "emoji" +	//		ORDER BY +	//			"shortcode_domain" ASC +	//	) AS "subquery" +	if err := e.conn. +		NewSelect(). +		Column("subquery.emoji_ids"). +		TableExpr("(?) AS ?", subQuery, bun.Ident("subquery")). +		Scan(ctx, &emojiIDs); err != nil { +		return nil, e.conn.ProcessError(err) +	} + +	if order == "DESC" { +		// Reverse the slice order so the caller still +		// gets emojis in expected a-z alphabetical order. +		// +		// See https://github.com/golang/go/wiki/SliceTricks#reversing +		for i := len(emojiIDs)/2 - 1; i >= 0; i-- { +			opp := len(emojiIDs) - 1 - i +			emojiIDs[i], emojiIDs[opp] = emojiIDs[opp], emojiIDs[i] +		} +	} + +	return e.emojisFromIDs(ctx, emojiIDs) +} + +func (e *emojiDB) GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Error) {  	emojiIDs := []string{}  	q := e.conn. diff --git a/internal/db/bundb/emoji_test.go b/internal/db/bundb/emoji_test.go index 0a1546d91..3c61fb620 100644 --- a/internal/db/bundb/emoji_test.go +++ b/internal/db/bundb/emoji_test.go @@ -23,20 +23,108 @@ import (  	"testing"  	"github.com/stretchr/testify/suite" +	"github.com/superseriousbusiness/gotosocial/internal/db"  )  type EmojiTestSuite struct {  	BunDBStandardTestSuite  } -func (suite *EmojiTestSuite) TestGetCustomEmojis() { -	emojis, err := suite.db.GetCustomEmojis(context.Background()) +func (suite *EmojiTestSuite) TestGetUseableEmojis() { +	emojis, err := suite.db.GetUseableEmojis(context.Background())  	suite.NoError(err)  	suite.Equal(1, len(emojis))  	suite.Equal("rainbow", emojis[0].Shortcode)  } +func (suite *EmojiTestSuite) TestGetAllEmojis() { +	emojis, err := suite.db.GetEmojis(context.Background(), db.EmojiAllDomains, true, true, "", "", "", 0) + +	suite.NoError(err) +	suite.Equal(2, len(emojis)) +	suite.Equal("rainbow", emojis[0].Shortcode) +	suite.Equal("yell", emojis[1].Shortcode) +} + +func (suite *EmojiTestSuite) TestGetAllEmojisLimit1() { +	emojis, err := suite.db.GetEmojis(context.Background(), db.EmojiAllDomains, true, true, "", "", "", 1) + +	suite.NoError(err) +	suite.Equal(1, len(emojis)) +	suite.Equal("rainbow", emojis[0].Shortcode) +} + +func (suite *EmojiTestSuite) TestGetAllEmojisMaxID() { +	emojis, err := suite.db.GetEmojis(context.Background(), db.EmojiAllDomains, true, true, "", "rainbow@", "", 0) + +	suite.NoError(err) +	suite.Equal(1, len(emojis)) +	suite.Equal("yell", emojis[0].Shortcode) +} + +func (suite *EmojiTestSuite) TestGetAllEmojisMinID() { +	emojis, err := suite.db.GetEmojis(context.Background(), db.EmojiAllDomains, true, true, "", "", "yell@fossbros-anonymous.io", 0) + +	suite.NoError(err) +	suite.Equal(1, len(emojis)) +	suite.Equal("rainbow", emojis[0].Shortcode) +} + +func (suite *EmojiTestSuite) TestGetAllDisabledEmojis() { +	emojis, err := suite.db.GetEmojis(context.Background(), db.EmojiAllDomains, true, false, "", "", "", 0) + +	suite.ErrorIs(err, db.ErrNoEntries) +	suite.Equal(0, len(emojis)) +} + +func (suite *EmojiTestSuite) TestGetAllEnabledEmojis() { +	emojis, err := suite.db.GetEmojis(context.Background(), db.EmojiAllDomains, false, true, "", "", "", 0) + +	suite.NoError(err) +	suite.Equal(2, len(emojis)) +	suite.Equal("rainbow", emojis[0].Shortcode) +	suite.Equal("yell", emojis[1].Shortcode) +} + +func (suite *EmojiTestSuite) TestGetLocalEnabledEmojis() { +	emojis, err := suite.db.GetEmojis(context.Background(), "", false, true, "", "", "", 0) + +	suite.NoError(err) +	suite.Equal(1, len(emojis)) +	suite.Equal("rainbow", emojis[0].Shortcode) +} + +func (suite *EmojiTestSuite) TestGetLocalDisabledEmojis() { +	emojis, err := suite.db.GetEmojis(context.Background(), "", true, false, "", "", "", 0) + +	suite.ErrorIs(err, db.ErrNoEntries) +	suite.Equal(0, len(emojis)) +} + +func (suite *EmojiTestSuite) TestGetAllEmojisFromDomain() { +	emojis, err := suite.db.GetEmojis(context.Background(), "peepee.poopoo", true, true, "", "", "", 0) + +	suite.ErrorIs(err, db.ErrNoEntries) +	suite.Equal(0, len(emojis)) +} + +func (suite *EmojiTestSuite) TestGetAllEmojisFromDomain2() { +	emojis, err := suite.db.GetEmojis(context.Background(), "fossbros-anonymous.io", true, true, "", "", "", 0) + +	suite.NoError(err) +	suite.Equal(1, len(emojis)) +	suite.Equal("yell", emojis[0].Shortcode) +} + +func (suite *EmojiTestSuite) TestGetSpecificEmojisFromDomain2() { +	emojis, err := suite.db.GetEmojis(context.Background(), "fossbros-anonymous.io", true, true, "yell", "", "", 0) + +	suite.NoError(err) +	suite.Equal(1, len(emojis)) +	suite.Equal("yell", emojis[0].Shortcode) +} +  func TestEmojiTestSuite(t *testing.T) {  	suite.Run(t, new(EmojiTestSuite))  }  | 
