summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/basic.go4
-rw-r--r--internal/db/bundb/basic.go64
-rw-r--r--internal/db/bundb/basic_test.go21
-rw-r--r--internal/db/bundb/status.go8
-rw-r--r--internal/db/bundb/status_test.go30
-rw-r--r--internal/db/bundb/util.go63
-rw-r--r--internal/db/params.go6
7 files changed, 164 insertions, 32 deletions
diff --git a/internal/db/basic.go b/internal/db/basic.go
index cf65ddc09..2a1141c8d 100644
--- a/internal/db/basic.go
+++ b/internal/db/basic.go
@@ -26,6 +26,10 @@ type Basic interface {
// For implementations that don't use tables, this can just return nil.
CreateTable(ctx context.Context, i interface{}) Error
+ // CreateAllTables creates *all* tables necessary for the running of GoToSocial.
+ // Because it uses the 'if not exists' parameter it is safe to run against a GtS that's already been initialized.
+ CreateAllTables(ctx context.Context) Error
+
// DropTable drops the table for the given interface.
// For implementations that don't use tables, this can just return nil.
DropTable(ctx context.Context, i interface{}) Error
diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go
index a3a8d0ae9..d4de5bb0b 100644
--- a/internal/db/bundb/basic.go
+++ b/internal/db/bundb/basic.go
@@ -24,6 +24,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
@@ -53,17 +54,8 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{})
}
q := b.conn.NewSelect().Model(i)
- for _, w := range where {
- if w.Value == nil {
- q = q.Where("? IS NULL", bun.Ident(w.Key))
- } else {
- if w.CaseInsensitive {
- q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value)
- } else {
- q = q.Where("? = ?", bun.Safe(w.Key), w.Value)
- }
- }
- }
+
+ selectWhere(q, where)
err := q.Scan(ctx)
return b.conn.ProcessError(err)
@@ -97,9 +89,7 @@ func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface
NewDelete().
Model(i)
- for _, w := range where {
- q = q.Where("? = ?", bun.Safe(w.Key), w.Value)
- }
+ deleteWhere(q, where)
_, err := q.Exec(ctx)
return b.conn.ProcessError(err)
@@ -128,17 +118,7 @@ func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, valu
func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error {
q := b.conn.NewUpdate().Model(i)
- for _, w := range where {
- if w.Value == nil {
- q = q.Where("? IS NULL", bun.Ident(w.Key))
- } else {
- if w.CaseInsensitive {
- q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value)
- } else {
- q = q.Where("? = ?", bun.Safe(w.Key), w.Value)
- }
- }
- }
+ updateWhere(q, where)
q = q.Set("? = ?", bun.Safe(key), value)
@@ -151,6 +131,40 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
return err
}
+func (b *basicDB) CreateAllTables(ctx context.Context) db.Error {
+ models := []interface{}{
+ &gtsmodel.Account{},
+ &gtsmodel.Application{},
+ &gtsmodel.Block{},
+ &gtsmodel.DomainBlock{},
+ &gtsmodel.EmailDomainBlock{},
+ &gtsmodel.Follow{},
+ &gtsmodel.FollowRequest{},
+ &gtsmodel.MediaAttachment{},
+ &gtsmodel.Mention{},
+ &gtsmodel.Status{},
+ &gtsmodel.StatusToEmoji{},
+ &gtsmodel.StatusToTag{},
+ &gtsmodel.StatusFave{},
+ &gtsmodel.StatusBookmark{},
+ &gtsmodel.StatusMute{},
+ &gtsmodel.Tag{},
+ &gtsmodel.User{},
+ &gtsmodel.Emoji{},
+ &gtsmodel.Instance{},
+ &gtsmodel.Notification{},
+ &gtsmodel.RouterSession{},
+ &gtsmodel.Token{},
+ &gtsmodel.Client{},
+ }
+ for _, i := range models {
+ if err := b.CreateTable(ctx, i); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error {
_, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx)
return b.conn.ProcessError(err)
diff --git a/internal/db/bundb/basic_test.go b/internal/db/bundb/basic_test.go
index d8067fb9d..e5f7e159a 100644
--- a/internal/db/bundb/basic_test.go
+++ b/internal/db/bundb/basic_test.go
@@ -23,6 +23,7 @@ import (
"testing"
"github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
@@ -42,7 +43,25 @@ func (suite *BasicTestSuite) TestGetAllStatuses() {
s := []*gtsmodel.Status{}
err := suite.db.GetAll(context.Background(), &s)
suite.NoError(err)
- suite.Len(s, 12)
+ suite.Len(s, 13)
+}
+
+func (suite *BasicTestSuite) TestGetAllNotNull() {
+ where := []db.Where{{
+ Key: "domain",
+ Value: nil,
+ Not: true,
+ }}
+
+ a := []*gtsmodel.Account{}
+
+ err := suite.db.GetWhere(context.Background(), where, &a)
+ suite.NoError(err)
+ suite.NotEmpty(a)
+
+ for _, acct := range a {
+ suite.NotEmpty(acct.Domain)
+ }
}
func TestBasicTestSuite(t *testing.T) {
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go
index 9464cfadf..2c26a7df9 100644
--- a/internal/db/bundb/status.go
+++ b/internal/db/bundb/status.go
@@ -240,11 +240,11 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status,
}
}
- // only do one loop if we only want direct children
- if onlyDirect {
- return
+ // if we're not only looking for direct children of status, then do the same children-finding
+ // operation for the found child status too.
+ if !onlyDirect {
+ s.statusChildren(ctx, child, foundStatuses, false, minID)
}
- s.statusChildren(ctx, child, foundStatuses, false, minID)
}
}
diff --git a/internal/db/bundb/status_test.go b/internal/db/bundb/status_test.go
index 4b4a5aca4..64079c78f 100644
--- a/internal/db/bundb/status_test.go
+++ b/internal/db/bundb/status_test.go
@@ -43,10 +43,14 @@ func (suite *StatusTestSuite) TestGetStatusByID() {
suite.Nil(status.BoostOfAccount)
suite.Nil(status.InReplyTo)
suite.Nil(status.InReplyToAccount)
+ suite.True(status.Federated)
+ suite.True(status.Boostable)
+ suite.True(status.Replyable)
+ suite.True(status.Likeable)
}
func (suite *StatusTestSuite) TestGetStatusByURI() {
- status, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
+ status, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_2_status_3"].URI)
if err != nil {
suite.FailNow(err.Error())
}
@@ -57,6 +61,10 @@ func (suite *StatusTestSuite) TestGetStatusByURI() {
suite.Nil(status.BoostOfAccount)
suite.Nil(status.InReplyTo)
suite.Nil(status.InReplyToAccount)
+ suite.True(status.Federated)
+ suite.True(status.Boostable)
+ suite.False(status.Replyable)
+ suite.False(status.Likeable)
}
func (suite *StatusTestSuite) TestGetStatusWithExtras() {
@@ -70,6 +78,10 @@ func (suite *StatusTestSuite) TestGetStatusWithExtras() {
suite.NotEmpty(status.Tags)
suite.NotEmpty(status.Attachments)
suite.NotEmpty(status.Emojis)
+ suite.True(status.Federated)
+ suite.True(status.Boostable)
+ suite.True(status.Replyable)
+ suite.True(status.Likeable)
}
func (suite *StatusTestSuite) TestGetStatusWithMention() {
@@ -83,6 +95,10 @@ func (suite *StatusTestSuite) TestGetStatusWithMention() {
suite.NotEmpty(status.MentionIDs)
suite.NotEmpty(status.InReplyToID)
suite.NotEmpty(status.InReplyToAccountID)
+ suite.True(status.Federated)
+ suite.True(status.Boostable)
+ suite.True(status.Replyable)
+ suite.True(status.Likeable)
}
func (suite *StatusTestSuite) TestGetStatusTwice() {
@@ -104,6 +120,18 @@ func (suite *StatusTestSuite) TestGetStatusTwice() {
suite.Less(duration2, duration1)
}
+func (suite *StatusTestSuite) TestGetStatusChildren() {
+ targetStatus := suite.testStatuses["local_account_1_status_1"]
+ children, err := suite.db.GetStatusChildren(context.Background(), targetStatus, true, "")
+ suite.NoError(err)
+ suite.Len(children, 2)
+ for _, c := range children {
+ suite.Equal(targetStatus.URI, c.InReplyToURI)
+ suite.Equal(targetStatus.AccountID, c.InReplyToAccountID)
+ suite.Equal(targetStatus.ID, c.InReplyToID)
+ }
+}
+
func TestStatusTestSuite(t *testing.T) {
suite.Run(t, new(StatusTestSuite))
}
diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go
index 9e1afb87e..459f65d8c 100644
--- a/internal/db/bundb/util.go
+++ b/internal/db/bundb/util.go
@@ -19,6 +19,7 @@
package bundb
import (
+ "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/uptrace/bun"
)
@@ -35,3 +36,65 @@ func whereEmptyOrNull(column string) func(*bun.SelectQuery) *bun.SelectQuery {
WhereOr("? = ''", bun.Ident(column))
}
}
+
+// updateWhere parses []db.Where and adds it to the given update query.
+func updateWhere(q *bun.UpdateQuery, where []db.Where) {
+ for _, w := range where {
+ query, args := parseWhere(w)
+ q = q.Where(query, args...)
+ }
+}
+
+// selectWhere parses []db.Where and adds it to the given select query.
+func selectWhere(q *bun.SelectQuery, where []db.Where) {
+ for _, w := range where {
+ query, args := parseWhere(w)
+ q = q.Where(query, args...)
+ }
+}
+
+// deleteWhere parses []db.Where and adds it to the given where query.
+func deleteWhere(q *bun.DeleteQuery, where []db.Where) {
+ for _, w := range where {
+ query, args := parseWhere(w)
+ q = q.Where(query, args...)
+ }
+}
+
+// parseWhere looks through the options on a single db.Where entry, and
+// returns the appropriate query string and arguments.
+func parseWhere(w db.Where) (query string, args []interface{}) {
+ if w.Not {
+ if w.Value == nil {
+ query = "? IS NOT NULL"
+ args = []interface{}{bun.Ident(w.Key)}
+ return
+ }
+
+ if w.CaseInsensitive {
+ query = "LOWER(?) != LOWER(?)"
+ args = []interface{}{bun.Safe(w.Key), w.Value}
+ return
+ }
+
+ query = "? != ?"
+ args = []interface{}{bun.Safe(w.Key), w.Value}
+ return
+ }
+
+ if w.Value == nil {
+ query = "? IS NULL"
+ args = []interface{}{bun.Ident(w.Key)}
+ return
+ }
+
+ if w.CaseInsensitive {
+ query = "LOWER(?) = LOWER(?)"
+ args = []interface{}{bun.Safe(w.Key), w.Value}
+ return
+ }
+
+ query = "? = ?"
+ args = []interface{}{bun.Safe(w.Key), w.Value}
+ return
+}
diff --git a/internal/db/params.go b/internal/db/params.go
index f0c384435..dbbf734a1 100644
--- a/internal/db/params.go
+++ b/internal/db/params.go
@@ -22,9 +22,13 @@ package db
type Where struct {
// The table to search on.
Key string
- // The value that must be set.
+ // The value to match.
Value interface{}
// Whether the value (if a string) should be case sensitive or not.
// Defaults to false.
CaseInsensitive bool
+ // If set, reverse the where.
+ // `WHERE k = v` becomes `WHERE k != v`.
+ // `WHERE k IS NULL` becomes `WHERE k IS NOT NULL`
+ Not bool
}