summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/bundb/domain.go148
-rw-r--r--internal/db/bundb/domain_test.go53
-rw-r--r--internal/db/bundb/migrations/20230908083121_allowlist.go.go62
-rw-r--r--internal/db/domain.go34
4 files changed, 291 insertions, 6 deletions
diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go
index c989d4fe4..dd626bc0a 100644
--- a/internal/db/bundb/domain.go
+++ b/internal/db/bundb/domain.go
@@ -23,6 +23,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
@@ -34,6 +35,102 @@ type domainDB struct {
state *state.State
}
+func (d *domainDB) CreateDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow) error {
+ // Normalize the domain as punycode
+ var err error
+ allow.Domain, err = util.Punify(allow.Domain)
+ if err != nil {
+ return err
+ }
+
+ // Attempt to store domain allow in DB
+ if _, err := d.db.NewInsert().
+ Model(allow).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // Clear the domain allow cache (for later reload)
+ d.state.Caches.GTS.DomainAllow().Clear()
+
+ return nil
+}
+
+func (d *domainDB) GetDomainAllow(ctx context.Context, domain string) (*gtsmodel.DomainAllow, error) {
+ // Normalize the domain as punycode
+ domain, err := util.Punify(domain)
+ if err != nil {
+ return nil, err
+ }
+
+ // Check for easy case, domain referencing *us*
+ if domain == "" || domain == config.GetAccountDomain() ||
+ domain == config.GetHost() {
+ return nil, db.ErrNoEntries
+ }
+
+ var allow gtsmodel.DomainAllow
+
+ // Look for allow matching domain in DB
+ q := d.db.
+ NewSelect().
+ Model(&allow).
+ Where("? = ?", bun.Ident("domain_allow.domain"), domain)
+ if err := q.Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return &allow, nil
+}
+
+func (d *domainDB) GetDomainAllows(ctx context.Context) ([]*gtsmodel.DomainAllow, error) {
+ allows := []*gtsmodel.DomainAllow{}
+
+ if err := d.db.
+ NewSelect().
+ Model(&allows).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return allows, nil
+}
+
+func (d *domainDB) GetDomainAllowByID(ctx context.Context, id string) (*gtsmodel.DomainAllow, error) {
+ var allow gtsmodel.DomainAllow
+
+ q := d.db.
+ NewSelect().
+ Model(&allow).
+ Where("? = ?", bun.Ident("domain_allow.id"), id)
+ if err := q.Scan(ctx); err != nil {
+ return nil, err
+ }
+
+ return &allow, nil
+}
+
+func (d *domainDB) DeleteDomainAllow(ctx context.Context, domain string) error {
+ // Normalize the domain as punycode
+ domain, err := util.Punify(domain)
+ if err != nil {
+ return err
+ }
+
+ // Attempt to delete domain allow
+ if _, err := d.db.NewDelete().
+ Model((*gtsmodel.DomainAllow)(nil)).
+ Where("? = ?", bun.Ident("domain_allow.domain"), domain).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // Clear the domain allow cache (for later reload)
+ d.state.Caches.GTS.DomainAllow().Clear()
+
+ return nil
+}
+
func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error {
// Normalize the domain as punycode
var err error
@@ -137,14 +234,32 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, er
return false, err
}
- // Check for easy case, domain referencing *us*
+ // Domain referencing *us* cannot be blocked.
if domain == "" || domain == config.GetAccountDomain() ||
domain == config.GetHost() {
return false, nil
}
+ // Check the cache for an explicit domain allow (hydrating the cache with callback if necessary).
+ explicitAllow, err := d.state.Caches.GTS.DomainAllow().Matches(domain, func() ([]string, error) {
+ var domains []string
+
+ // Scan list of all explicitly allowed domains from DB
+ q := d.db.NewSelect().
+ Table("domain_allows").
+ Column("domain")
+ if err := q.Scan(ctx, &domains); err != nil {
+ return nil, err
+ }
+
+ return domains, nil
+ })
+ if err != nil {
+ return false, err
+ }
+
// Check the cache for a domain block (hydrating the cache with callback if necessary)
- return d.state.Caches.GTS.DomainBlock().IsBlocked(domain, func() ([]string, error) {
+ explicitBlock, err := d.state.Caches.GTS.DomainBlock().Matches(domain, func() ([]string, error) {
var domains []string
// Scan list of all blocked domains from DB
@@ -157,6 +272,35 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, er
return domains, nil
})
+ if err != nil {
+ return false, err
+ }
+
+ // Calculate if blocked
+ // based on federation mode.
+ switch mode := config.GetInstanceFederationMode(); mode {
+
+ case config.InstanceFederationModeBlocklist:
+ // Blocklist/default mode: explicit allow
+ // takes precedence over explicit block.
+ //
+ // Domains that have neither block
+ // or allow entries are allowed.
+ return !(explicitAllow || !explicitBlock), nil
+
+ case config.InstanceFederationModeAllowlist:
+ // Allowlist mode: explicit block takes
+ // precedence over explicit allow.
+ //
+ // Domains that have neither block
+ // or allow entries are blocked.
+ return (explicitBlock || !explicitAllow), nil
+
+ default:
+ // This should never happen but account
+ // for it anyway to make the code tidier.
+ return false, gtserror.Newf("unrecognized federation mode: %s", mode)
+ }
}
func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, error) {
diff --git a/internal/db/bundb/domain_test.go b/internal/db/bundb/domain_test.go
index e4e199fa1..ff687cf59 100644
--- a/internal/db/bundb/domain_test.go
+++ b/internal/db/bundb/domain_test.go
@@ -55,6 +55,59 @@ func (suite *DomainTestSuite) TestIsDomainBlocked() {
suite.WithinDuration(time.Now(), domainBlock.CreatedAt, 10*time.Second)
}
+func (suite *DomainTestSuite) TestIsDomainBlockedWithAllow() {
+ ctx := context.Background()
+
+ domainBlock := &gtsmodel.DomainBlock{
+ ID: "01G204214Y9TNJEBX39C7G88SW",
+ Domain: "some.bad.apples",
+ CreatedByAccountID: suite.testAccounts["admin_account"].ID,
+ CreatedByAccount: suite.testAccounts["admin_account"],
+ }
+
+ // no domain block exists for the given domain yet
+ blocked, err := suite.db.IsDomainBlocked(ctx, domainBlock.Domain)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.False(blocked)
+
+ // Block this domain.
+ if err := suite.db.CreateDomainBlock(ctx, domainBlock); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // domain block now exists
+ blocked, err = suite.db.IsDomainBlocked(ctx, domainBlock.Domain)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.True(blocked)
+ suite.WithinDuration(time.Now(), domainBlock.CreatedAt, 10*time.Second)
+
+ // Explicitly allow this domain.
+ domainAllow := &gtsmodel.DomainAllow{
+ ID: "01H8KY9MJQFWE712EG3VN02Y3J",
+ Domain: "some.bad.apples",
+ CreatedByAccountID: suite.testAccounts["admin_account"].ID,
+ CreatedByAccount: suite.testAccounts["admin_account"],
+ }
+
+ if err := suite.db.CreateDomainAllow(ctx, domainAllow); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // Domain allow now exists
+ blocked, err = suite.db.IsDomainBlocked(ctx, domainBlock.Domain)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.False(blocked)
+}
+
func (suite *DomainTestSuite) TestIsDomainBlockedWildcard() {
ctx := context.Background()
diff --git a/internal/db/bundb/migrations/20230908083121_allowlist.go.go b/internal/db/bundb/migrations/20230908083121_allowlist.go.go
new file mode 100644
index 000000000..2d86f8c03
--- /dev/null
+++ b/internal/db/bundb/migrations/20230908083121_allowlist.go.go
@@ -0,0 +1,62 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package migrations
+
+import (
+ "context"
+
+ gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+func init() {
+ up := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ // Create domain allow.
+ if _, err := tx.
+ NewCreateTable().
+ Model(&gtsmodel.DomainAllow{}).
+ IfNotExists().
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // Index domain allow.
+ if _, err := tx.
+ NewCreateIndex().
+ Table("domain_allows").
+ Index("domain_allows_domain_idx").
+ Column("domain").
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ return nil
+ })
+ }
+
+ down := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ return nil
+ })
+ }
+
+ if err := Migrations.Register(up, down); err != nil {
+ panic(err)
+ }
+}
diff --git a/internal/db/domain.go b/internal/db/domain.go
index 740ccefe6..3f7803d62 100644
--- a/internal/db/domain.go
+++ b/internal/db/domain.go
@@ -26,6 +26,25 @@ import (
// Domain contains DB functions related to domains and domain blocks.
type Domain interface {
+ /*
+ Block/allow storage + retrieval functions.
+ */
+
+ // CreateDomainAllow puts the given instance-level domain allow into the database.
+ CreateDomainAllow(ctx context.Context, allow *gtsmodel.DomainAllow) error
+
+ // GetDomainAllow returns one instance-level domain allow with the given domain, if it exists.
+ GetDomainAllow(ctx context.Context, domain string) (*gtsmodel.DomainAllow, error)
+
+ // GetDomainAllowByID returns one instance-level domain allow with the given id, if it exists.
+ GetDomainAllowByID(ctx context.Context, id string) (*gtsmodel.DomainAllow, error)
+
+ // GetDomainAllows returns all instance-level domain allows currently enforced by this instance.
+ GetDomainAllows(ctx context.Context) ([]*gtsmodel.DomainAllow, error)
+
+ // DeleteDomainAllow deletes an instance-level domain allow with the given domain, if it exists.
+ DeleteDomainAllow(ctx context.Context, domain string) error
+
// CreateDomainBlock puts the given instance-level domain block into the database.
CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error
@@ -41,15 +60,22 @@ type Domain interface {
// DeleteDomainBlock deletes an instance-level domain block with the given domain, if it exists.
DeleteDomainBlock(ctx context.Context, domain string) error
- // IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`).
+ /*
+ Block/allow checking functions.
+ */
+
+ // IsDomainBlocked checks if domain is blocked, accounting for both explicit allows and blocks.
+ // Will check allows first, so an allowed domain will always return false, even if it's also blocked.
IsDomainBlocked(ctx context.Context, domain string) (bool, error)
- // AreDomainsBlocked checks if an instance-level domain block exists for any of the given domains strings, and returns true if even one is found.
+ // AreDomainsBlocked calls IsDomainBlocked for each domain.
+ // Will return true if even one of the given domains is blocked.
AreDomainsBlocked(ctx context.Context, domains []string) (bool, error)
- // IsURIBlocked checks if an instance-level domain block exists for the `host` in the given URI (eg., `https://example.org/users/whatever`).
+ // IsURIBlocked calls IsDomainBlocked for the host of the given URI.
IsURIBlocked(ctx context.Context, uri *url.URL) (bool, error)
- // AreURIsBlocked checks if an instance-level domain block exists for any `host` in the given URI slice, and returns true if even one is found.
+ // AreURIsBlocked calls IsURIBlocked for each URI.
+ // Will return true if even one of the given URIs is blocked.
AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, error)
}