diff options
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/basic.go | 4 | ||||
-rw-r--r-- | internal/db/bundb/basic.go | 64 | ||||
-rw-r--r-- | internal/db/bundb/basic_test.go | 21 | ||||
-rw-r--r-- | internal/db/bundb/status.go | 8 | ||||
-rw-r--r-- | internal/db/bundb/status_test.go | 30 | ||||
-rw-r--r-- | internal/db/bundb/util.go | 63 | ||||
-rw-r--r-- | internal/db/params.go | 6 |
7 files changed, 164 insertions, 32 deletions
diff --git a/internal/db/basic.go b/internal/db/basic.go index cf65ddc09..2a1141c8d 100644 --- a/internal/db/basic.go +++ b/internal/db/basic.go @@ -26,6 +26,10 @@ type Basic interface { // For implementations that don't use tables, this can just return nil. CreateTable(ctx context.Context, i interface{}) Error + // CreateAllTables creates *all* tables necessary for the running of GoToSocial. + // Because it uses the 'if not exists' parameter it is safe to run against a GtS that's already been initialized. + CreateAllTables(ctx context.Context) Error + // DropTable drops the table for the given interface. // For implementations that don't use tables, this can just return nil. DropTable(ctx context.Context, i interface{}) Error diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go index a3a8d0ae9..d4de5bb0b 100644 --- a/internal/db/bundb/basic.go +++ b/internal/db/bundb/basic.go @@ -24,6 +24,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/uptrace/bun" ) @@ -53,17 +54,8 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) } q := b.conn.NewSelect().Model(i) - for _, w := range where { - if w.Value == nil { - q = q.Where("? IS NULL", bun.Ident(w.Key)) - } else { - if w.CaseInsensitive { - q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value) - } else { - q = q.Where("? = ?", bun.Safe(w.Key), w.Value) - } - } - } + + selectWhere(q, where) err := q.Scan(ctx) return b.conn.ProcessError(err) @@ -97,9 +89,7 @@ func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface NewDelete(). Model(i) - for _, w := range where { - q = q.Where("? = ?", bun.Safe(w.Key), w.Value) - } + deleteWhere(q, where) _, err := q.Exec(ctx) return b.conn.ProcessError(err) @@ -128,17 +118,7 @@ func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, valu func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error { q := b.conn.NewUpdate().Model(i) - for _, w := range where { - if w.Value == nil { - q = q.Where("? IS NULL", bun.Ident(w.Key)) - } else { - if w.CaseInsensitive { - q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value) - } else { - q = q.Where("? = ?", bun.Safe(w.Key), w.Value) - } - } - } + updateWhere(q, where) q = q.Set("? = ?", bun.Safe(key), value) @@ -151,6 +131,40 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error { return err } +func (b *basicDB) CreateAllTables(ctx context.Context) db.Error { + models := []interface{}{ + >smodel.Account{}, + >smodel.Application{}, + >smodel.Block{}, + >smodel.DomainBlock{}, + >smodel.EmailDomainBlock{}, + >smodel.Follow{}, + >smodel.FollowRequest{}, + >smodel.MediaAttachment{}, + >smodel.Mention{}, + >smodel.Status{}, + >smodel.StatusToEmoji{}, + >smodel.StatusToTag{}, + >smodel.StatusFave{}, + >smodel.StatusBookmark{}, + >smodel.StatusMute{}, + >smodel.Tag{}, + >smodel.User{}, + >smodel.Emoji{}, + >smodel.Instance{}, + >smodel.Notification{}, + >smodel.RouterSession{}, + >smodel.Token{}, + >smodel.Client{}, + } + for _, i := range models { + if err := b.CreateTable(ctx, i); err != nil { + return err + } + } + return nil +} + func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error { _, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx) return b.conn.ProcessError(err) diff --git a/internal/db/bundb/basic_test.go b/internal/db/bundb/basic_test.go index d8067fb9d..e5f7e159a 100644 --- a/internal/db/bundb/basic_test.go +++ b/internal/db/bundb/basic_test.go @@ -23,6 +23,7 @@ import ( "testing" "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) @@ -42,7 +43,25 @@ func (suite *BasicTestSuite) TestGetAllStatuses() { s := []*gtsmodel.Status{} err := suite.db.GetAll(context.Background(), &s) suite.NoError(err) - suite.Len(s, 12) + suite.Len(s, 13) +} + +func (suite *BasicTestSuite) TestGetAllNotNull() { + where := []db.Where{{ + Key: "domain", + Value: nil, + Not: true, + }} + + a := []*gtsmodel.Account{} + + err := suite.db.GetWhere(context.Background(), where, &a) + suite.NoError(err) + suite.NotEmpty(a) + + for _, acct := range a { + suite.NotEmpty(acct.Domain) + } } func TestBasicTestSuite(t *testing.T) { diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index 9464cfadf..2c26a7df9 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -240,11 +240,11 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, } } - // only do one loop if we only want direct children - if onlyDirect { - return + // if we're not only looking for direct children of status, then do the same children-finding + // operation for the found child status too. + if !onlyDirect { + s.statusChildren(ctx, child, foundStatuses, false, minID) } - s.statusChildren(ctx, child, foundStatuses, false, minID) } } diff --git a/internal/db/bundb/status_test.go b/internal/db/bundb/status_test.go index 4b4a5aca4..64079c78f 100644 --- a/internal/db/bundb/status_test.go +++ b/internal/db/bundb/status_test.go @@ -43,10 +43,14 @@ func (suite *StatusTestSuite) TestGetStatusByID() { suite.Nil(status.BoostOfAccount) suite.Nil(status.InReplyTo) suite.Nil(status.InReplyToAccount) + suite.True(status.Federated) + suite.True(status.Boostable) + suite.True(status.Replyable) + suite.True(status.Likeable) } func (suite *StatusTestSuite) TestGetStatusByURI() { - status, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI) + status, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_2_status_3"].URI) if err != nil { suite.FailNow(err.Error()) } @@ -57,6 +61,10 @@ func (suite *StatusTestSuite) TestGetStatusByURI() { suite.Nil(status.BoostOfAccount) suite.Nil(status.InReplyTo) suite.Nil(status.InReplyToAccount) + suite.True(status.Federated) + suite.True(status.Boostable) + suite.False(status.Replyable) + suite.False(status.Likeable) } func (suite *StatusTestSuite) TestGetStatusWithExtras() { @@ -70,6 +78,10 @@ func (suite *StatusTestSuite) TestGetStatusWithExtras() { suite.NotEmpty(status.Tags) suite.NotEmpty(status.Attachments) suite.NotEmpty(status.Emojis) + suite.True(status.Federated) + suite.True(status.Boostable) + suite.True(status.Replyable) + suite.True(status.Likeable) } func (suite *StatusTestSuite) TestGetStatusWithMention() { @@ -83,6 +95,10 @@ func (suite *StatusTestSuite) TestGetStatusWithMention() { suite.NotEmpty(status.MentionIDs) suite.NotEmpty(status.InReplyToID) suite.NotEmpty(status.InReplyToAccountID) + suite.True(status.Federated) + suite.True(status.Boostable) + suite.True(status.Replyable) + suite.True(status.Likeable) } func (suite *StatusTestSuite) TestGetStatusTwice() { @@ -104,6 +120,18 @@ func (suite *StatusTestSuite) TestGetStatusTwice() { suite.Less(duration2, duration1) } +func (suite *StatusTestSuite) TestGetStatusChildren() { + targetStatus := suite.testStatuses["local_account_1_status_1"] + children, err := suite.db.GetStatusChildren(context.Background(), targetStatus, true, "") + suite.NoError(err) + suite.Len(children, 2) + for _, c := range children { + suite.Equal(targetStatus.URI, c.InReplyToURI) + suite.Equal(targetStatus.AccountID, c.InReplyToAccountID) + suite.Equal(targetStatus.ID, c.InReplyToID) + } +} + func TestStatusTestSuite(t *testing.T) { suite.Run(t, new(StatusTestSuite)) } diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go index 9e1afb87e..459f65d8c 100644 --- a/internal/db/bundb/util.go +++ b/internal/db/bundb/util.go @@ -19,6 +19,7 @@ package bundb import ( + "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/uptrace/bun" ) @@ -35,3 +36,65 @@ func whereEmptyOrNull(column string) func(*bun.SelectQuery) *bun.SelectQuery { WhereOr("? = ''", bun.Ident(column)) } } + +// updateWhere parses []db.Where and adds it to the given update query. +func updateWhere(q *bun.UpdateQuery, where []db.Where) { + for _, w := range where { + query, args := parseWhere(w) + q = q.Where(query, args...) + } +} + +// selectWhere parses []db.Where and adds it to the given select query. +func selectWhere(q *bun.SelectQuery, where []db.Where) { + for _, w := range where { + query, args := parseWhere(w) + q = q.Where(query, args...) + } +} + +// deleteWhere parses []db.Where and adds it to the given where query. +func deleteWhere(q *bun.DeleteQuery, where []db.Where) { + for _, w := range where { + query, args := parseWhere(w) + q = q.Where(query, args...) + } +} + +// parseWhere looks through the options on a single db.Where entry, and +// returns the appropriate query string and arguments. +func parseWhere(w db.Where) (query string, args []interface{}) { + if w.Not { + if w.Value == nil { + query = "? IS NOT NULL" + args = []interface{}{bun.Ident(w.Key)} + return + } + + if w.CaseInsensitive { + query = "LOWER(?) != LOWER(?)" + args = []interface{}{bun.Safe(w.Key), w.Value} + return + } + + query = "? != ?" + args = []interface{}{bun.Safe(w.Key), w.Value} + return + } + + if w.Value == nil { + query = "? IS NULL" + args = []interface{}{bun.Ident(w.Key)} + return + } + + if w.CaseInsensitive { + query = "LOWER(?) = LOWER(?)" + args = []interface{}{bun.Safe(w.Key), w.Value} + return + } + + query = "? = ?" + args = []interface{}{bun.Safe(w.Key), w.Value} + return +} diff --git a/internal/db/params.go b/internal/db/params.go index f0c384435..dbbf734a1 100644 --- a/internal/db/params.go +++ b/internal/db/params.go @@ -22,9 +22,13 @@ package db type Where struct { // The table to search on. Key string - // The value that must be set. + // The value to match. Value interface{} // Whether the value (if a string) should be case sensitive or not. // Defaults to false. CaseInsensitive bool + // If set, reverse the where. + // `WHERE k = v` becomes `WHERE k != v`. + // `WHERE k IS NULL` becomes `WHERE k IS NOT NULL` + Not bool } |