summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/basic.go4
-rw-r--r--internal/db/bundb/basic.go34
-rw-r--r--internal/db/bundb/bundb.go5
-rw-r--r--internal/db/bundb/headerfilter.go207
-rw-r--r--internal/db/bundb/headerfilter_test.go125
-rw-r--r--internal/db/bundb/migrations/20231212144715_add_header_filters.go54
-rw-r--r--internal/db/db.go1
-rw-r--r--internal/db/headerfilter.go73
8 files changed, 465 insertions, 38 deletions
diff --git a/internal/db/basic.go b/internal/db/basic.go
index 7cd690aef..3a8e2af8d 100644
--- a/internal/db/basic.go
+++ b/internal/db/basic.go
@@ -25,10 +25,6 @@ type Basic interface {
// For implementations that don't use tables, this can just return nil.
CreateTable(ctx context.Context, i interface{}) error
- // CreateAllTables creates *all* tables necessary for the running of GoToSocial.
- // Because it uses the 'if not exists' parameter it is safe to run against a GtS that's already been initialized.
- CreateAllTables(ctx context.Context) error
-
// DropTable drops the table for the given interface.
// For implementations that don't use tables, this can just return nil.
DropTable(ctx context.Context, i interface{}) error
diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go
index e68903efa..488f59ad5 100644
--- a/internal/db/bundb/basic.go
+++ b/internal/db/bundb/basic.go
@@ -22,7 +22,6 @@ import (
"errors"
"github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/uptrace/bun"
)
@@ -120,39 +119,6 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) error {
return err
}
-func (b *basicDB) CreateAllTables(ctx context.Context) error {
- models := []interface{}{
- &gtsmodel.Account{},
- &gtsmodel.Application{},
- &gtsmodel.Block{},
- &gtsmodel.DomainBlock{},
- &gtsmodel.EmailDomainBlock{},
- &gtsmodel.Follow{},
- &gtsmodel.FollowRequest{},
- &gtsmodel.MediaAttachment{},
- &gtsmodel.Mention{},
- &gtsmodel.Status{},
- &gtsmodel.StatusToEmoji{},
- &gtsmodel.StatusFave{},
- &gtsmodel.StatusBookmark{},
- &gtsmodel.ThreadMute{},
- &gtsmodel.Tag{},
- &gtsmodel.User{},
- &gtsmodel.Emoji{},
- &gtsmodel.Instance{},
- &gtsmodel.Notification{},
- &gtsmodel.RouterSession{},
- &gtsmodel.Token{},
- &gtsmodel.Client{},
- }
- for _, i := range models {
- if err := b.CreateTable(ctx, i); err != nil {
- return err
- }
- }
- return nil
-}
-
func (b *basicDB) DropTable(ctx context.Context, i interface{}) error {
_, err := b.db.NewDropTable().Model(i).IfExists().Exec(ctx)
return err
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index f7417cfeb..d9415eff4 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -67,6 +67,7 @@ type DBService struct {
db.Basic
db.Domain
db.Emoji
+ db.HeaderFilter
db.Instance
db.List
db.Marker
@@ -193,6 +194,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
db: db,
state: state,
},
+ HeaderFilter: &headerFilterDB{
+ db: db,
+ state: state,
+ },
Instance: &instanceDB{
db: db,
state: state,
diff --git a/internal/db/bundb/headerfilter.go b/internal/db/bundb/headerfilter.go
new file mode 100644
index 000000000..087b65c82
--- /dev/null
+++ b/internal/db/bundb/headerfilter.go
@@ -0,0 +1,207 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb
+
+import (
+ "context"
+ "net/http"
+ "time"
+ "unsafe"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/uptrace/bun"
+)
+
+type headerFilterDB struct {
+ db *DB
+ state *state.State
+}
+
+func (h *headerFilterDB) AllowHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) {
+ return h.state.Caches.AllowHeaderFilters.RegularMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) {
+ return h.GetAllowHeaderFilters(ctx)
+ })
+}
+
+func (h *headerFilterDB) AllowHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) {
+ return h.state.Caches.AllowHeaderFilters.InverseMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) {
+ return h.GetAllowHeaderFilters(ctx)
+ })
+}
+
+func (h *headerFilterDB) BlockHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) {
+ return h.state.Caches.BlockHeaderFilters.RegularMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) {
+ return h.GetBlockHeaderFilters(ctx)
+ })
+}
+
+func (h *headerFilterDB) BlockHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) {
+ return h.state.Caches.BlockHeaderFilters.InverseMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) {
+ return h.GetBlockHeaderFilters(ctx)
+ })
+}
+
+func (h *headerFilterDB) GetAllowHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) {
+ filter := new(gtsmodel.HeaderFilterAllow)
+ if err := h.db.NewSelect().
+ Model(filter).
+ Where("? = ?", bun.Ident("id"), id).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+ return fromAllowFilter(filter), nil
+}
+
+func (h *headerFilterDB) GetBlockHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) {
+ filter := new(gtsmodel.HeaderFilterBlock)
+ if err := h.db.NewSelect().
+ Model(filter).
+ Where("? = ?", bun.Ident("id"), id).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+ return fromBlockFilter(filter), nil
+}
+
+func (h *headerFilterDB) GetAllowHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) {
+ var filters []*gtsmodel.HeaderFilterAllow
+ err := h.db.NewSelect().
+ Model(&filters).
+ Scan(ctx, &filters)
+ return fromAllowFilters(filters), err
+}
+
+func (h *headerFilterDB) GetBlockHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) {
+ var filters []*gtsmodel.HeaderFilterBlock
+ err := h.db.NewSelect().
+ Model(&filters).
+ Scan(ctx, &filters)
+ return fromBlockFilters(filters), err
+}
+
+func (h *headerFilterDB) PutAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error {
+ if _, err := h.db.NewInsert().
+ Model(toAllowFilter(filter)).
+ Exec(ctx); err != nil {
+ return err
+ }
+ h.state.Caches.AllowHeaderFilters.Clear()
+ return nil
+}
+
+func (h *headerFilterDB) PutBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error {
+ if _, err := h.db.NewInsert().
+ Model(toBlockFilter(filter)).
+ Exec(ctx); err != nil {
+ return err
+ }
+ h.state.Caches.BlockHeaderFilters.Clear()
+ return nil
+}
+
+func (h *headerFilterDB) UpdateAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error {
+ filter.UpdatedAt = time.Now()
+ if len(cols) > 0 {
+ // If we're updating by column,
+ // ensure "updated_at" is included.
+ cols = append(cols, "updated_at")
+ }
+ if _, err := h.db.NewUpdate().
+ Model(toAllowFilter(filter)).
+ Column(cols...).
+ Where("? = ?", bun.Ident("id"), filter.ID).
+ Exec(ctx); err != nil {
+ return err
+ }
+ h.state.Caches.AllowHeaderFilters.Clear()
+ return nil
+}
+
+func (h *headerFilterDB) UpdateBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error {
+ filter.UpdatedAt = time.Now()
+ if len(cols) > 0 {
+ // If we're updating by column,
+ // ensure "updated_at" is included.
+ cols = append(cols, "updated_at")
+ }
+ if _, err := h.db.NewUpdate().
+ Model(toBlockFilter(filter)).
+ Column(cols...).
+ Where("? = ?", bun.Ident("id"), filter.ID).
+ Exec(ctx); err != nil {
+ return err
+ }
+ h.state.Caches.BlockHeaderFilters.Clear()
+ return nil
+}
+
+func (h *headerFilterDB) DeleteAllowHeaderFilter(ctx context.Context, id string) error {
+ if _, err := h.db.NewDelete().
+ Table("header_filter_allows").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx); err != nil {
+ return err
+ }
+ h.state.Caches.AllowHeaderFilters.Clear()
+ return nil
+}
+
+func (h *headerFilterDB) DeleteBlockHeaderFilter(ctx context.Context, id string) error {
+ if _, err := h.db.NewDelete().
+ Table("header_filter_blocks").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx); err != nil {
+ return err
+ }
+ h.state.Caches.BlockHeaderFilters.Clear()
+ return nil
+}
+
+// NOTE:
+// all of the below unsafe cast functions
+// are only possible because HeaderFilterAllow{},
+// HeaderFilterBlock{}, HeaderFilter{} while
+// different types in source, have exactly the
+// same size and layout in memory. the unsafe
+// cast simply changes the type associated with
+// that block of memory.
+
+func toAllowFilter(filter *gtsmodel.HeaderFilter) *gtsmodel.HeaderFilterAllow {
+ return (*gtsmodel.HeaderFilterAllow)(unsafe.Pointer(filter))
+}
+
+func toBlockFilter(filter *gtsmodel.HeaderFilter) *gtsmodel.HeaderFilterBlock {
+ return (*gtsmodel.HeaderFilterBlock)(unsafe.Pointer(filter))
+}
+
+func fromAllowFilter(filter *gtsmodel.HeaderFilterAllow) *gtsmodel.HeaderFilter {
+ return (*gtsmodel.HeaderFilter)(unsafe.Pointer(filter))
+}
+
+func fromBlockFilter(filter *gtsmodel.HeaderFilterBlock) *gtsmodel.HeaderFilter {
+ return (*gtsmodel.HeaderFilter)(unsafe.Pointer(filter))
+}
+
+func fromAllowFilters(filters []*gtsmodel.HeaderFilterAllow) []*gtsmodel.HeaderFilter {
+ return *(*[]*gtsmodel.HeaderFilter)(unsafe.Pointer(&filters))
+}
+
+func fromBlockFilters(filters []*gtsmodel.HeaderFilterBlock) []*gtsmodel.HeaderFilter {
+ return *(*[]*gtsmodel.HeaderFilter)(unsafe.Pointer(&filters))
+}
diff --git a/internal/db/bundb/headerfilter_test.go b/internal/db/bundb/headerfilter_test.go
new file mode 100644
index 000000000..d7e2b26ee
--- /dev/null
+++ b/internal/db/bundb/headerfilter_test.go
@@ -0,0 +1,125 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+type HeaderFilterTestSuite struct {
+ BunDBStandardTestSuite
+}
+
+func (suite *HeaderFilterTestSuite) TestAllowHeaderFilterGetPutUpdateDelete() {
+ suite.testHeaderFilterGetPutUpdateDelete(
+ suite.db.GetAllowHeaderFilter,
+ suite.db.GetAllowHeaderFilters,
+ suite.db.PutAllowHeaderFilter,
+ suite.db.UpdateAllowHeaderFilter,
+ suite.db.DeleteAllowHeaderFilter,
+ )
+}
+
+func (suite *HeaderFilterTestSuite) TestBlockHeaderFilterGetPutUpdateDelete() {
+ suite.testHeaderFilterGetPutUpdateDelete(
+ suite.db.GetBlockHeaderFilter,
+ suite.db.GetBlockHeaderFilters,
+ suite.db.PutBlockHeaderFilter,
+ suite.db.UpdateBlockHeaderFilter,
+ suite.db.DeleteBlockHeaderFilter,
+ )
+}
+
+func (suite *HeaderFilterTestSuite) testHeaderFilterGetPutUpdateDelete(
+ get func(context.Context, string) (*gtsmodel.HeaderFilter, error),
+ getAll func(context.Context) ([]*gtsmodel.HeaderFilter, error),
+ put func(context.Context, *gtsmodel.HeaderFilter) error,
+ update func(context.Context, *gtsmodel.HeaderFilter, ...string) error,
+ delete func(context.Context, string) error,
+) {
+ t := suite.T()
+
+ // Create new example header filter.
+ filter := gtsmodel.HeaderFilter{
+ ID: "some unique id",
+ Header: "Http-Header-Key",
+ Regex: ".*",
+ AuthorID: "some unique author id",
+ }
+
+ // Create new cancellable test context.
+ ctx := context.Background()
+ ctx, cncl := context.WithCancel(ctx)
+ defer cncl()
+
+ // Insert the example header filter into db.
+ if err := put(ctx, &filter); err != nil {
+ t.Fatalf("error inserting header filter: %v", err)
+ }
+
+ // Now fetch newly created filter.
+ check, err := get(ctx, filter.ID)
+ if err != nil {
+ t.Fatalf("error fetching header filter: %v", err)
+ }
+
+ // Check all expected fields match.
+ suite.Equal(filter.ID, check.ID)
+ suite.Equal(filter.Header, check.Header)
+ suite.Equal(filter.Regex, check.Regex)
+ suite.Equal(filter.AuthorID, check.AuthorID)
+
+ // Fetch all header filters.
+ all, err := getAll(ctx)
+ if err != nil {
+ t.Fatalf("error fetching header filters: %v", err)
+ }
+
+ // Ensure contains example.
+ suite.Equal(len(all), 1)
+ suite.Equal(all[0].ID, filter.ID)
+
+ // Update the header filter regex value.
+ check.Regex = "new regex value"
+ if err := update(ctx, check); err != nil {
+ t.Fatalf("error updating header filter: %v", err)
+ }
+
+ // Ensure 'updated_at' was updated on check model.
+ suite.True(check.UpdatedAt.After(filter.UpdatedAt))
+
+ // Now delete the header filter from db.
+ if err := delete(ctx, filter.ID); err != nil {
+ t.Fatalf("error deleting header filter: %v", err)
+ }
+
+ // Ensure we can't refetch it.
+ _, err = get(ctx, filter.ID)
+ if err != db.ErrNoEntries {
+ t.Fatalf("deleted header filter returned unexpected error: %v", err)
+ }
+}
+
+func TestHeaderFilterTestSuite(t *testing.T) {
+ suite.Run(t, new(HeaderFilterTestSuite))
+}
diff --git a/internal/db/bundb/migrations/20231212144715_add_header_filters.go b/internal/db/bundb/migrations/20231212144715_add_header_filters.go
new file mode 100644
index 000000000..2d671bf46
--- /dev/null
+++ b/internal/db/bundb/migrations/20231212144715_add_header_filters.go
@@ -0,0 +1,54 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package migrations
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+func init() {
+ up := func(ctx context.Context, db *bun.DB) error {
+ for _, model := range []any{
+ &gtsmodel.HeaderFilterAllow{},
+ &gtsmodel.HeaderFilterBlock{},
+ } {
+ _, err := db.NewCreateTable().
+ IfNotExists().
+ Model(model).
+ Exec(ctx)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+ }
+
+ down := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ return nil
+ })
+ }
+
+ if err := Migrations.Register(up, down); err != nil {
+ panic(err)
+ }
+}
diff --git a/internal/db/db.go b/internal/db/db.go
index 2914d9b59..361687e94 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -30,6 +30,7 @@ type DB interface {
Basic
Domain
Emoji
+ HeaderFilter
Instance
List
Marker
diff --git a/internal/db/headerfilter.go b/internal/db/headerfilter.go
new file mode 100644
index 000000000..5fe8a5b17
--- /dev/null
+++ b/internal/db/headerfilter.go
@@ -0,0 +1,73 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package db
+
+import (
+ "context"
+ "net/http"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+type HeaderFilter interface {
+ // AllowHeaderRegularMatch performs an headerfilter.Filter.RegularMatch() on cached allow header filters.
+ // (Note: the actual matching code can be found under ./internal/headerfilter/ ).
+ AllowHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error)
+
+ // AllowHeaderInverseMatch performs an headerfilter.Filter.InverseMatch() on cached allow header filters.
+ // (Note: the actual matching code can be found under ./internal/headerfilter/ ).
+ AllowHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error)
+
+ // BlockHeaderRegularMatch performs an headerfilter.Filter.RegularMatch() on cached block header filters.
+ // (Note: the actual matching code can be found under ./internal/headerfilter/ ).
+ BlockHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error)
+
+ // BlockHeaderInverseMatch performs an headerfilter.Filter.InverseMatch() on cached block header filters.
+ // (Note: the actual matching code can be found under ./internal/headerfilter/ ).
+ BlockHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error)
+
+ // GetAllowHeaderFilter fetches the allow header filter with ID from the database.
+ GetAllowHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error)
+
+ // GetBlockHeaderFilter fetches the block header filter with ID from the database.
+ GetBlockHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error)
+
+ // GetAllowHeaderFilters fetches all allow header filters from the database.
+ GetAllowHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error)
+
+ // GetBlockHeaderFilters fetches all block header filters from the database.
+ GetBlockHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error)
+
+ // PutAllowHeaderFilter inserts the given allow header filter into the database.
+ PutAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error
+
+ // PutBlockHeaderFilter inserts the given block header filter into the database.
+ PutBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error
+
+ // UpdateAllowHeaderFilter updates the given allow header filter in the database, only updating given columns if provided.
+ UpdateAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error
+
+ // UpdateBlockHeaderFilter updates the given block header filter in the database, only updating given columns if provided.
+ UpdateBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error
+
+ // DeleteAllowHeaderFilter deletes the allow header filter with ID from the database.
+ DeleteAllowHeaderFilter(ctx context.Context, id string) error
+
+ // DeleteBlockHeaderFilter deletes the block header filter with ID from the database.
+ DeleteBlockHeaderFilter(ctx context.Context, id string) error
+}