diff options
Diffstat (limited to 'internal/db/bundb')
| -rw-r--r-- | internal/db/bundb/basic.go | 64 | ||||
| -rw-r--r-- | internal/db/bundb/basic_test.go | 21 | ||||
| -rw-r--r-- | internal/db/bundb/status.go | 8 | ||||
| -rw-r--r-- | internal/db/bundb/status_test.go | 30 | ||||
| -rw-r--r-- | internal/db/bundb/util.go | 63 | 
5 files changed, 155 insertions, 31 deletions
| diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go index a3a8d0ae9..d4de5bb0b 100644 --- a/internal/db/bundb/basic.go +++ b/internal/db/bundb/basic.go @@ -24,6 +24,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/uptrace/bun"  ) @@ -53,17 +54,8 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{})  	}  	q := b.conn.NewSelect().Model(i) -	for _, w := range where { -		if w.Value == nil { -			q = q.Where("? IS NULL", bun.Ident(w.Key)) -		} else { -			if w.CaseInsensitive { -				q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value) -			} else { -				q = q.Where("? = ?", bun.Safe(w.Key), w.Value) -			} -		} -	} + +	selectWhere(q, where)  	err := q.Scan(ctx)  	return b.conn.ProcessError(err) @@ -97,9 +89,7 @@ func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface  		NewDelete().  		Model(i) -	for _, w := range where { -		q = q.Where("? = ?", bun.Safe(w.Key), w.Value) -	} +	deleteWhere(q, where)  	_, err := q.Exec(ctx)  	return b.conn.ProcessError(err) @@ -128,17 +118,7 @@ func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, valu  func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error {  	q := b.conn.NewUpdate().Model(i) -	for _, w := range where { -		if w.Value == nil { -			q = q.Where("? IS NULL", bun.Ident(w.Key)) -		} else { -			if w.CaseInsensitive { -				q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value) -			} else { -				q = q.Where("? = ?", bun.Safe(w.Key), w.Value) -			} -		} -	} +	updateWhere(q, where)  	q = q.Set("? = ?", bun.Safe(key), value) @@ -151,6 +131,40 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {  	return err  } +func (b *basicDB) CreateAllTables(ctx context.Context) db.Error { +	models := []interface{}{ +		>smodel.Account{}, +		>smodel.Application{}, +		>smodel.Block{}, +		>smodel.DomainBlock{}, +		>smodel.EmailDomainBlock{}, +		>smodel.Follow{}, +		>smodel.FollowRequest{}, +		>smodel.MediaAttachment{}, +		>smodel.Mention{}, +		>smodel.Status{}, +		>smodel.StatusToEmoji{}, +		>smodel.StatusToTag{}, +		>smodel.StatusFave{}, +		>smodel.StatusBookmark{}, +		>smodel.StatusMute{}, +		>smodel.Tag{}, +		>smodel.User{}, +		>smodel.Emoji{}, +		>smodel.Instance{}, +		>smodel.Notification{}, +		>smodel.RouterSession{}, +		>smodel.Token{}, +		>smodel.Client{}, +	} +	for _, i := range models { +		if err := b.CreateTable(ctx, i); err != nil { +			return err +		} +	} +	return nil +} +  func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error {  	_, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx)  	return b.conn.ProcessError(err) diff --git a/internal/db/bundb/basic_test.go b/internal/db/bundb/basic_test.go index d8067fb9d..e5f7e159a 100644 --- a/internal/db/bundb/basic_test.go +++ b/internal/db/bundb/basic_test.go @@ -23,6 +23,7 @@ import (  	"testing"  	"github.com/stretchr/testify/suite" +	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  ) @@ -42,7 +43,25 @@ func (suite *BasicTestSuite) TestGetAllStatuses() {  	s := []*gtsmodel.Status{}  	err := suite.db.GetAll(context.Background(), &s)  	suite.NoError(err) -	suite.Len(s, 12) +	suite.Len(s, 13) +} + +func (suite *BasicTestSuite) TestGetAllNotNull() { +	where := []db.Where{{ +		Key:   "domain", +		Value: nil, +		Not:   true, +	}} + +	a := []*gtsmodel.Account{} + +	err := suite.db.GetWhere(context.Background(), where, &a) +	suite.NoError(err) +	suite.NotEmpty(a) + +	for _, acct := range a { +		suite.NotEmpty(acct.Domain) +	}  }  func TestBasicTestSuite(t *testing.T) { diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index 9464cfadf..2c26a7df9 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -240,11 +240,11 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status,  			}  		} -		// only do one loop if we only want direct children -		if onlyDirect { -			return +		// if we're not only looking for direct children of status, then do the same children-finding +		// operation for the found child status too. +		if !onlyDirect { +			s.statusChildren(ctx, child, foundStatuses, false, minID)  		} -		s.statusChildren(ctx, child, foundStatuses, false, minID)  	}  } diff --git a/internal/db/bundb/status_test.go b/internal/db/bundb/status_test.go index 4b4a5aca4..64079c78f 100644 --- a/internal/db/bundb/status_test.go +++ b/internal/db/bundb/status_test.go @@ -43,10 +43,14 @@ func (suite *StatusTestSuite) TestGetStatusByID() {  	suite.Nil(status.BoostOfAccount)  	suite.Nil(status.InReplyTo)  	suite.Nil(status.InReplyToAccount) +	suite.True(status.Federated) +	suite.True(status.Boostable) +	suite.True(status.Replyable) +	suite.True(status.Likeable)  }  func (suite *StatusTestSuite) TestGetStatusByURI() { -	status, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI) +	status, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_2_status_3"].URI)  	if err != nil {  		suite.FailNow(err.Error())  	} @@ -57,6 +61,10 @@ func (suite *StatusTestSuite) TestGetStatusByURI() {  	suite.Nil(status.BoostOfAccount)  	suite.Nil(status.InReplyTo)  	suite.Nil(status.InReplyToAccount) +	suite.True(status.Federated) +	suite.True(status.Boostable) +	suite.False(status.Replyable) +	suite.False(status.Likeable)  }  func (suite *StatusTestSuite) TestGetStatusWithExtras() { @@ -70,6 +78,10 @@ func (suite *StatusTestSuite) TestGetStatusWithExtras() {  	suite.NotEmpty(status.Tags)  	suite.NotEmpty(status.Attachments)  	suite.NotEmpty(status.Emojis) +	suite.True(status.Federated) +	suite.True(status.Boostable) +	suite.True(status.Replyable) +	suite.True(status.Likeable)  }  func (suite *StatusTestSuite) TestGetStatusWithMention() { @@ -83,6 +95,10 @@ func (suite *StatusTestSuite) TestGetStatusWithMention() {  	suite.NotEmpty(status.MentionIDs)  	suite.NotEmpty(status.InReplyToID)  	suite.NotEmpty(status.InReplyToAccountID) +	suite.True(status.Federated) +	suite.True(status.Boostable) +	suite.True(status.Replyable) +	suite.True(status.Likeable)  }  func (suite *StatusTestSuite) TestGetStatusTwice() { @@ -104,6 +120,18 @@ func (suite *StatusTestSuite) TestGetStatusTwice() {  	suite.Less(duration2, duration1)  } +func (suite *StatusTestSuite) TestGetStatusChildren() { +	targetStatus := suite.testStatuses["local_account_1_status_1"] +	children, err := suite.db.GetStatusChildren(context.Background(), targetStatus, true, "") +	suite.NoError(err) +	suite.Len(children, 2) +	for _, c := range children { +		suite.Equal(targetStatus.URI, c.InReplyToURI) +		suite.Equal(targetStatus.AccountID, c.InReplyToAccountID) +		suite.Equal(targetStatus.ID, c.InReplyToID) +	} +} +  func TestStatusTestSuite(t *testing.T) {  	suite.Run(t, new(StatusTestSuite))  } diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go index 9e1afb87e..459f65d8c 100644 --- a/internal/db/bundb/util.go +++ b/internal/db/bundb/util.go @@ -19,6 +19,7 @@  package bundb  import ( +	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/uptrace/bun"  ) @@ -35,3 +36,65 @@ func whereEmptyOrNull(column string) func(*bun.SelectQuery) *bun.SelectQuery {  			WhereOr("? = ''", bun.Ident(column))  	}  } + +// updateWhere parses []db.Where and adds it to the given update query. +func updateWhere(q *bun.UpdateQuery, where []db.Where) { +	for _, w := range where { +		query, args := parseWhere(w) +		q = q.Where(query, args...) +	} +} + +// selectWhere parses []db.Where and adds it to the given select query. +func selectWhere(q *bun.SelectQuery, where []db.Where) { +	for _, w := range where { +		query, args := parseWhere(w) +		q = q.Where(query, args...) +	} +} + +// deleteWhere parses []db.Where and adds it to the given where query. +func deleteWhere(q *bun.DeleteQuery, where []db.Where) { +	for _, w := range where { +		query, args := parseWhere(w) +		q = q.Where(query, args...) +	} +} + +// parseWhere looks through the options on a single db.Where entry, and +// returns the appropriate query string and arguments. +func parseWhere(w db.Where) (query string, args []interface{}) { +	if w.Not { +		if w.Value == nil { +			query = "? IS NOT NULL" +			args = []interface{}{bun.Ident(w.Key)} +			return +		} + +		if w.CaseInsensitive { +			query = "LOWER(?) != LOWER(?)" +			args = []interface{}{bun.Safe(w.Key), w.Value} +			return +		} + +		query = "? != ?" +		args = []interface{}{bun.Safe(w.Key), w.Value} +		return +	} + +	if w.Value == nil { +		query = "? IS NULL" +		args = []interface{}{bun.Ident(w.Key)} +		return +	} + +	if w.CaseInsensitive { +		query = "LOWER(?) = LOWER(?)" +		args = []interface{}{bun.Safe(w.Key), w.Value} +		return +	} + +	query = "? = ?" +	args = []interface{}{bun.Safe(w.Key), w.Value} +	return +} | 
