summaryrefslogtreecommitdiff
path: root/internal/db/bundb
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb')
-rw-r--r--internal/db/bundb/basic.go34
-rw-r--r--internal/db/bundb/bundb.go5
-rw-r--r--internal/db/bundb/headerfilter.go207
-rw-r--r--internal/db/bundb/headerfilter_test.go125
-rw-r--r--internal/db/bundb/migrations/20231212144715_add_header_filters.go54
5 files changed, 391 insertions, 34 deletions
diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go
index e68903efa..488f59ad5 100644
--- a/internal/db/bundb/basic.go
+++ b/internal/db/bundb/basic.go
@@ -22,7 +22,6 @@ import (
"errors"
"github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/uptrace/bun"
)
@@ -120,39 +119,6 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) error {
return err
}
-func (b *basicDB) CreateAllTables(ctx context.Context) error {
- models := []interface{}{
- &gtsmodel.Account{},
- &gtsmodel.Application{},
- &gtsmodel.Block{},
- &gtsmodel.DomainBlock{},
- &gtsmodel.EmailDomainBlock{},
- &gtsmodel.Follow{},
- &gtsmodel.FollowRequest{},
- &gtsmodel.MediaAttachment{},
- &gtsmodel.Mention{},
- &gtsmodel.Status{},
- &gtsmodel.StatusToEmoji{},
- &gtsmodel.StatusFave{},
- &gtsmodel.StatusBookmark{},
- &gtsmodel.ThreadMute{},
- &gtsmodel.Tag{},
- &gtsmodel.User{},
- &gtsmodel.Emoji{},
- &gtsmodel.Instance{},
- &gtsmodel.Notification{},
- &gtsmodel.RouterSession{},
- &gtsmodel.Token{},
- &gtsmodel.Client{},
- }
- for _, i := range models {
- if err := b.CreateTable(ctx, i); err != nil {
- return err
- }
- }
- return nil
-}
-
func (b *basicDB) DropTable(ctx context.Context, i interface{}) error {
_, err := b.db.NewDropTable().Model(i).IfExists().Exec(ctx)
return err
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index f7417cfeb..d9415eff4 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -67,6 +67,7 @@ type DBService struct {
db.Basic
db.Domain
db.Emoji
+ db.HeaderFilter
db.Instance
db.List
db.Marker
@@ -193,6 +194,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
db: db,
state: state,
},
+ HeaderFilter: &headerFilterDB{
+ db: db,
+ state: state,
+ },
Instance: &instanceDB{
db: db,
state: state,
diff --git a/internal/db/bundb/headerfilter.go b/internal/db/bundb/headerfilter.go
new file mode 100644
index 000000000..087b65c82
--- /dev/null
+++ b/internal/db/bundb/headerfilter.go
@@ -0,0 +1,207 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb
+
+import (
+ "context"
+ "net/http"
+ "time"
+ "unsafe"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/uptrace/bun"
+)
+
+type headerFilterDB struct {
+ db *DB
+ state *state.State
+}
+
+func (h *headerFilterDB) AllowHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) {
+ return h.state.Caches.AllowHeaderFilters.RegularMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) {
+ return h.GetAllowHeaderFilters(ctx)
+ })
+}
+
+func (h *headerFilterDB) AllowHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) {
+ return h.state.Caches.AllowHeaderFilters.InverseMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) {
+ return h.GetAllowHeaderFilters(ctx)
+ })
+}
+
+func (h *headerFilterDB) BlockHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) {
+ return h.state.Caches.BlockHeaderFilters.RegularMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) {
+ return h.GetBlockHeaderFilters(ctx)
+ })
+}
+
+func (h *headerFilterDB) BlockHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) {
+ return h.state.Caches.BlockHeaderFilters.InverseMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) {
+ return h.GetBlockHeaderFilters(ctx)
+ })
+}
+
+func (h *headerFilterDB) GetAllowHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) {
+ filter := new(gtsmodel.HeaderFilterAllow)
+ if err := h.db.NewSelect().
+ Model(filter).
+ Where("? = ?", bun.Ident("id"), id).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+ return fromAllowFilter(filter), nil
+}
+
+func (h *headerFilterDB) GetBlockHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) {
+ filter := new(gtsmodel.HeaderFilterBlock)
+ if err := h.db.NewSelect().
+ Model(filter).
+ Where("? = ?", bun.Ident("id"), id).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+ return fromBlockFilter(filter), nil
+}
+
+func (h *headerFilterDB) GetAllowHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) {
+ var filters []*gtsmodel.HeaderFilterAllow
+ err := h.db.NewSelect().
+ Model(&filters).
+ Scan(ctx, &filters)
+ return fromAllowFilters(filters), err
+}
+
+func (h *headerFilterDB) GetBlockHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) {
+ var filters []*gtsmodel.HeaderFilterBlock
+ err := h.db.NewSelect().
+ Model(&filters).
+ Scan(ctx, &filters)
+ return fromBlockFilters(filters), err
+}
+
+func (h *headerFilterDB) PutAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error {
+ if _, err := h.db.NewInsert().
+ Model(toAllowFilter(filter)).
+ Exec(ctx); err != nil {
+ return err
+ }
+ h.state.Caches.AllowHeaderFilters.Clear()
+ return nil
+}
+
+func (h *headerFilterDB) PutBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error {
+ if _, err := h.db.NewInsert().
+ Model(toBlockFilter(filter)).
+ Exec(ctx); err != nil {
+ return err
+ }
+ h.state.Caches.BlockHeaderFilters.Clear()
+ return nil
+}
+
+func (h *headerFilterDB) UpdateAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error {
+ filter.UpdatedAt = time.Now()
+ if len(cols) > 0 {
+ // If we're updating by column,
+ // ensure "updated_at" is included.
+ cols = append(cols, "updated_at")
+ }
+ if _, err := h.db.NewUpdate().
+ Model(toAllowFilter(filter)).
+ Column(cols...).
+ Where("? = ?", bun.Ident("id"), filter.ID).
+ Exec(ctx); err != nil {
+ return err
+ }
+ h.state.Caches.AllowHeaderFilters.Clear()
+ return nil
+}
+
+func (h *headerFilterDB) UpdateBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error {
+ filter.UpdatedAt = time.Now()
+ if len(cols) > 0 {
+ // If we're updating by column,
+ // ensure "updated_at" is included.
+ cols = append(cols, "updated_at")
+ }
+ if _, err := h.db.NewUpdate().
+ Model(toBlockFilter(filter)).
+ Column(cols...).
+ Where("? = ?", bun.Ident("id"), filter.ID).
+ Exec(ctx); err != nil {
+ return err
+ }
+ h.state.Caches.BlockHeaderFilters.Clear()
+ return nil
+}
+
+func (h *headerFilterDB) DeleteAllowHeaderFilter(ctx context.Context, id string) error {
+ if _, err := h.db.NewDelete().
+ Table("header_filter_allows").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx); err != nil {
+ return err
+ }
+ h.state.Caches.AllowHeaderFilters.Clear()
+ return nil
+}
+
+func (h *headerFilterDB) DeleteBlockHeaderFilter(ctx context.Context, id string) error {
+ if _, err := h.db.NewDelete().
+ Table("header_filter_blocks").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx); err != nil {
+ return err
+ }
+ h.state.Caches.BlockHeaderFilters.Clear()
+ return nil
+}
+
+// NOTE:
+// all of the below unsafe cast functions
+// are only possible because HeaderFilterAllow{},
+// HeaderFilterBlock{}, HeaderFilter{} while
+// different types in source, have exactly the
+// same size and layout in memory. the unsafe
+// cast simply changes the type associated with
+// that block of memory.
+
+func toAllowFilter(filter *gtsmodel.HeaderFilter) *gtsmodel.HeaderFilterAllow {
+ return (*gtsmodel.HeaderFilterAllow)(unsafe.Pointer(filter))
+}
+
+func toBlockFilter(filter *gtsmodel.HeaderFilter) *gtsmodel.HeaderFilterBlock {
+ return (*gtsmodel.HeaderFilterBlock)(unsafe.Pointer(filter))
+}
+
+func fromAllowFilter(filter *gtsmodel.HeaderFilterAllow) *gtsmodel.HeaderFilter {
+ return (*gtsmodel.HeaderFilter)(unsafe.Pointer(filter))
+}
+
+func fromBlockFilter(filter *gtsmodel.HeaderFilterBlock) *gtsmodel.HeaderFilter {
+ return (*gtsmodel.HeaderFilter)(unsafe.Pointer(filter))
+}
+
+func fromAllowFilters(filters []*gtsmodel.HeaderFilterAllow) []*gtsmodel.HeaderFilter {
+ return *(*[]*gtsmodel.HeaderFilter)(unsafe.Pointer(&filters))
+}
+
+func fromBlockFilters(filters []*gtsmodel.HeaderFilterBlock) []*gtsmodel.HeaderFilter {
+ return *(*[]*gtsmodel.HeaderFilter)(unsafe.Pointer(&filters))
+}
diff --git a/internal/db/bundb/headerfilter_test.go b/internal/db/bundb/headerfilter_test.go
new file mode 100644
index 000000000..d7e2b26ee
--- /dev/null
+++ b/internal/db/bundb/headerfilter_test.go
@@ -0,0 +1,125 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+type HeaderFilterTestSuite struct {
+ BunDBStandardTestSuite
+}
+
+func (suite *HeaderFilterTestSuite) TestAllowHeaderFilterGetPutUpdateDelete() {
+ suite.testHeaderFilterGetPutUpdateDelete(
+ suite.db.GetAllowHeaderFilter,
+ suite.db.GetAllowHeaderFilters,
+ suite.db.PutAllowHeaderFilter,
+ suite.db.UpdateAllowHeaderFilter,
+ suite.db.DeleteAllowHeaderFilter,
+ )
+}
+
+func (suite *HeaderFilterTestSuite) TestBlockHeaderFilterGetPutUpdateDelete() {
+ suite.testHeaderFilterGetPutUpdateDelete(
+ suite.db.GetBlockHeaderFilter,
+ suite.db.GetBlockHeaderFilters,
+ suite.db.PutBlockHeaderFilter,
+ suite.db.UpdateBlockHeaderFilter,
+ suite.db.DeleteBlockHeaderFilter,
+ )
+}
+
+func (suite *HeaderFilterTestSuite) testHeaderFilterGetPutUpdateDelete(
+ get func(context.Context, string) (*gtsmodel.HeaderFilter, error),
+ getAll func(context.Context) ([]*gtsmodel.HeaderFilter, error),
+ put func(context.Context, *gtsmodel.HeaderFilter) error,
+ update func(context.Context, *gtsmodel.HeaderFilter, ...string) error,
+ delete func(context.Context, string) error,
+) {
+ t := suite.T()
+
+ // Create new example header filter.
+ filter := gtsmodel.HeaderFilter{
+ ID: "some unique id",
+ Header: "Http-Header-Key",
+ Regex: ".*",
+ AuthorID: "some unique author id",
+ }
+
+ // Create new cancellable test context.
+ ctx := context.Background()
+ ctx, cncl := context.WithCancel(ctx)
+ defer cncl()
+
+ // Insert the example header filter into db.
+ if err := put(ctx, &filter); err != nil {
+ t.Fatalf("error inserting header filter: %v", err)
+ }
+
+ // Now fetch newly created filter.
+ check, err := get(ctx, filter.ID)
+ if err != nil {
+ t.Fatalf("error fetching header filter: %v", err)
+ }
+
+ // Check all expected fields match.
+ suite.Equal(filter.ID, check.ID)
+ suite.Equal(filter.Header, check.Header)
+ suite.Equal(filter.Regex, check.Regex)
+ suite.Equal(filter.AuthorID, check.AuthorID)
+
+ // Fetch all header filters.
+ all, err := getAll(ctx)
+ if err != nil {
+ t.Fatalf("error fetching header filters: %v", err)
+ }
+
+ // Ensure contains example.
+ suite.Equal(len(all), 1)
+ suite.Equal(all[0].ID, filter.ID)
+
+ // Update the header filter regex value.
+ check.Regex = "new regex value"
+ if err := update(ctx, check); err != nil {
+ t.Fatalf("error updating header filter: %v", err)
+ }
+
+ // Ensure 'updated_at' was updated on check model.
+ suite.True(check.UpdatedAt.After(filter.UpdatedAt))
+
+ // Now delete the header filter from db.
+ if err := delete(ctx, filter.ID); err != nil {
+ t.Fatalf("error deleting header filter: %v", err)
+ }
+
+ // Ensure we can't refetch it.
+ _, err = get(ctx, filter.ID)
+ if err != db.ErrNoEntries {
+ t.Fatalf("deleted header filter returned unexpected error: %v", err)
+ }
+}
+
+func TestHeaderFilterTestSuite(t *testing.T) {
+ suite.Run(t, new(HeaderFilterTestSuite))
+}
diff --git a/internal/db/bundb/migrations/20231212144715_add_header_filters.go b/internal/db/bundb/migrations/20231212144715_add_header_filters.go
new file mode 100644
index 000000000..2d671bf46
--- /dev/null
+++ b/internal/db/bundb/migrations/20231212144715_add_header_filters.go
@@ -0,0 +1,54 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package migrations
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+func init() {
+ up := func(ctx context.Context, db *bun.DB) error {
+ for _, model := range []any{
+ &gtsmodel.HeaderFilterAllow{},
+ &gtsmodel.HeaderFilterBlock{},
+ } {
+ _, err := db.NewCreateTable().
+ IfNotExists().
+ Model(model).
+ Exec(ctx)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+ }
+
+ down := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ return nil
+ })
+ }
+
+ if err := Migrations.Register(up, down); err != nil {
+ panic(err)
+ }
+}