summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/bundb/account.go3
-rw-r--r--internal/db/bundb/conn.go5
-rw-r--r--internal/db/bundb/domain.go6
-rw-r--r--internal/db/bundb/domain_test.go57
-rw-r--r--internal/db/bundb/relationship.go19
-rw-r--r--internal/db/bundb/relationship_test.go41
-rw-r--r--internal/db/bundb/status.go6
-rw-r--r--internal/db/bundb/trace.go5
8 files changed, 128 insertions, 14 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index 876fb5186..59292055e 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -22,6 +22,7 @@ import (
"context"
"errors"
"fmt"
+ "strings"
"time"
"github.com/spf13/viper"
@@ -199,7 +200,7 @@ func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username stri
account := new(gtsmodel.Account)
q := a.newAccountQ(account).
- Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username). // ignore casing
+ Where("username = ?", strings.ToLower(username)). // usernames on our instance will always be lowercase
WhereGroup(" AND ", whereEmptyOrNull("domain"))
if err := q.Scan(ctx); err != nil {
diff --git a/internal/db/bundb/conn.go b/internal/db/bundb/conn.go
index 3b5a3ac92..baa0baeae 100644
--- a/internal/db/bundb/conn.go
+++ b/internal/db/bundb/conn.go
@@ -68,13 +68,12 @@ func (conn *DBConn) ProcessError(err error) db.Error {
// Exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors
func (conn *DBConn) Exists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) {
- // Get the select query result
- count, err := query.Count(ctx)
+ exists, err := query.Exists(ctx)
// Process error as our own and check if it exists
switch err := conn.ProcessError(err); err {
case nil:
- return (count != 0), nil
+ return exists, nil
case db.ErrNoEntries:
return false, nil
default:
diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go
index 417b2becd..e63a584bd 100644
--- a/internal/db/bundb/domain.go
+++ b/internal/db/bundb/domain.go
@@ -21,6 +21,7 @@ package bundb
import (
"context"
"net/url"
+ "strings"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@@ -39,7 +40,8 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db
q := d.conn.
NewSelect().
Model(&gtsmodel.DomainBlock{}).
- Where("LOWER(domain) = LOWER(?)", domain).
+ ExcludeColumn("id", "created_at", "updated_at", "created_by_account_id", "private_comment", "public_comment", "obfuscate", "subscription_id").
+ Where("domain = ?", domain).
Limit(1)
return d.conn.Exists(ctx, q)
@@ -50,7 +52,7 @@ func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (boo
uniqueDomains := util.UniqueStrings(domains)
for _, domain := range uniqueDomains {
- if blocked, err := d.IsDomainBlocked(ctx, domain); err != nil {
+ if blocked, err := d.IsDomainBlocked(ctx, strings.ToLower(domain)); err != nil {
return false, err
} else if blocked {
return blocked, nil
diff --git a/internal/db/bundb/domain_test.go b/internal/db/bundb/domain_test.go
new file mode 100644
index 000000000..1a3fed24d
--- /dev/null
+++ b/internal/db/bundb/domain_test.go
@@ -0,0 +1,57 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+type DomainTestSuite struct {
+ BunDBStandardTestSuite
+}
+
+func (suite *DomainTestSuite) TestIsDomainBlocked() {
+ ctx := context.Background()
+
+ domainBlock := &gtsmodel.DomainBlock{
+ ID: "01G204214Y9TNJEBX39C7G88SW",
+ Domain: "some.bad.apples",
+ CreatedByAccountID: suite.testAccounts["admin_account"].ID,
+ }
+
+ // no domain block exists for the given domain yet
+ blocked, err := suite.db.IsDomainBlocked(ctx, domainBlock.Domain)
+ suite.NoError(err)
+ suite.False(blocked)
+
+ suite.db.Put(ctx, domainBlock)
+
+ // domain block now exists
+ blocked, err = suite.db.IsDomainBlocked(ctx, domainBlock.Domain)
+ suite.NoError(err)
+ suite.True(blocked)
+}
+
+func TestDomainTestSuite(t *testing.T) {
+ suite.Run(t, new(DomainTestSuite))
+}
diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go
index 369553205..e2e2c96b2 100644
--- a/internal/db/bundb/relationship.go
+++ b/internal/db/bundb/relationship.go
@@ -52,14 +52,25 @@ func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account
q := r.conn.
NewSelect().
Model(&gtsmodel.Block{}).
- Where("account_id = ?", account1).
- Where("target_account_id = ?", account2).
+ ExcludeColumn("id", "created_at", "updated_at", "uri").
Limit(1)
if eitherDirection {
q = q.
- WhereOr("target_account_id = ?", account1).
- Where("account_id = ?", account2)
+ WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery {
+ return inner.
+ Where("account_id = ?", account1).
+ Where("target_account_id = ?", account2)
+ }).
+ WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery {
+ return inner.
+ Where("account_id = ?", account2).
+ Where("target_account_id = ?", account1)
+ })
+ } else {
+ q = q.
+ Where("account_id = ?", account1).
+ Where("target_account_id = ?", account2)
}
return r.conn.Exists(ctx, q)
diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go
index bb0f0e3da..34fe85a57 100644
--- a/internal/db/bundb/relationship_test.go
+++ b/internal/db/bundb/relationship_test.go
@@ -25,6 +25,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type RelationshipTestSuite struct {
@@ -32,7 +33,45 @@ type RelationshipTestSuite struct {
}
func (suite *RelationshipTestSuite) TestIsBlocked() {
- suite.Suite.T().Skip("TODO: implement")
+ ctx := context.Background()
+
+ account1 := suite.testAccounts["local_account_1"].ID
+ account2 := suite.testAccounts["local_account_2"].ID
+
+ // no blocks exist between account 1 and account 2
+ blocked, err := suite.db.IsBlocked(ctx, account1, account2, false)
+ suite.NoError(err)
+ suite.False(blocked)
+
+ blocked, err = suite.db.IsBlocked(ctx, account2, account1, false)
+ suite.NoError(err)
+ suite.False(blocked)
+
+ // have account1 block account2
+ suite.db.Put(ctx, &gtsmodel.Block{
+ ID: "01G202BCSXXJZ70BHB5KCAHH8C",
+ URI: "http://localhost:8080/some_block_uri_1",
+ AccountID: account1,
+ TargetAccountID: account2,
+ })
+
+ // account 1 now blocks account 2
+ blocked, err = suite.db.IsBlocked(ctx, account1, account2, false)
+ suite.NoError(err)
+ suite.True(blocked)
+
+ // account 2 doesn't block account 1
+ blocked, err = suite.db.IsBlocked(ctx, account2, account1, false)
+ suite.NoError(err)
+ suite.False(blocked)
+
+ // a block exists in either direction between the two
+ blocked, err = suite.db.IsBlocked(ctx, account1, account2, true)
+ suite.NoError(err)
+ suite.True(blocked)
+ blocked, err = suite.db.IsBlocked(ctx, account2, account1, true)
+ suite.NoError(err)
+ suite.True(blocked)
}
func (suite *RelationshipTestSuite) TestGetBlock() {
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go
index 1783723bb..4e670f59b 100644
--- a/internal/db/bundb/status.go
+++ b/internal/db/bundb/status.go
@@ -70,7 +70,7 @@ func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Stat
return s.cache.GetByID(id)
},
func(status *gtsmodel.Status) error {
- return s.newStatusQ(status).Where("LOWER(status.id) = LOWER(?)", id).Scan(ctx)
+ return s.newStatusQ(status).Where("status.id = ?", id).Scan(ctx)
},
)
}
@@ -82,7 +82,7 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St
return s.cache.GetByURI(uri)
},
func(status *gtsmodel.Status) error {
- return s.newStatusQ(status).Where("LOWER(status.uri) = LOWER(?)", uri).Scan(ctx)
+ return s.newStatusQ(status).Where("status.uri = ?", uri).Scan(ctx)
},
)
}
@@ -94,7 +94,7 @@ func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.St
return s.cache.GetByURL(url)
},
func(status *gtsmodel.Status) error {
- return s.newStatusQ(status).Where("LOWER(status.url) = LOWER(?)", url).Scan(ctx)
+ return s.newStatusQ(status).Where("status.url = ?", url).Scan(ctx)
},
)
}
diff --git a/internal/db/bundb/trace.go b/internal/db/bundb/trace.go
index 93c231785..27b5e22ac 100644
--- a/internal/db/bundb/trace.go
+++ b/internal/db/bundb/trace.go
@@ -47,6 +47,11 @@ func (q *debugQueryHook) AfterQuery(_ context.Context, event *bun.QueryEvent) {
"operation": event.Operation(),
})
+ if dur > 1*time.Second {
+ l.Warnf("SLOW DATABASE QUERY [%s] %s", dur, event.Query)
+ return
+ }
+
if logrus.GetLevel() == logrus.TraceLevel {
l.Tracef("[%s] %s", dur, event.Query)
} else {