diff options
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/basic.go | 4 | ||||
-rw-r--r-- | internal/db/bundb/basic.go | 34 | ||||
-rw-r--r-- | internal/db/bundb/bundb.go | 5 | ||||
-rw-r--r-- | internal/db/bundb/headerfilter.go | 207 | ||||
-rw-r--r-- | internal/db/bundb/headerfilter_test.go | 125 | ||||
-rw-r--r-- | internal/db/bundb/migrations/20231212144715_add_header_filters.go | 54 | ||||
-rw-r--r-- | internal/db/db.go | 1 | ||||
-rw-r--r-- | internal/db/headerfilter.go | 73 |
8 files changed, 465 insertions, 38 deletions
diff --git a/internal/db/basic.go b/internal/db/basic.go index 7cd690aef..3a8e2af8d 100644 --- a/internal/db/basic.go +++ b/internal/db/basic.go @@ -25,10 +25,6 @@ type Basic interface { // For implementations that don't use tables, this can just return nil. CreateTable(ctx context.Context, i interface{}) error - // CreateAllTables creates *all* tables necessary for the running of GoToSocial. - // Because it uses the 'if not exists' parameter it is safe to run against a GtS that's already been initialized. - CreateAllTables(ctx context.Context) error - // DropTable drops the table for the given interface. // For implementations that don't use tables, this can just return nil. DropTable(ctx context.Context, i interface{}) error diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go index e68903efa..488f59ad5 100644 --- a/internal/db/bundb/basic.go +++ b/internal/db/bundb/basic.go @@ -22,7 +22,6 @@ import ( "errors" "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/uptrace/bun" ) @@ -120,39 +119,6 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) error { return err } -func (b *basicDB) CreateAllTables(ctx context.Context) error { - models := []interface{}{ - >smodel.Account{}, - >smodel.Application{}, - >smodel.Block{}, - >smodel.DomainBlock{}, - >smodel.EmailDomainBlock{}, - >smodel.Follow{}, - >smodel.FollowRequest{}, - >smodel.MediaAttachment{}, - >smodel.Mention{}, - >smodel.Status{}, - >smodel.StatusToEmoji{}, - >smodel.StatusFave{}, - >smodel.StatusBookmark{}, - >smodel.ThreadMute{}, - >smodel.Tag{}, - >smodel.User{}, - >smodel.Emoji{}, - >smodel.Instance{}, - >smodel.Notification{}, - >smodel.RouterSession{}, - >smodel.Token{}, - >smodel.Client{}, - } - for _, i := range models { - if err := b.CreateTable(ctx, i); err != nil { - return err - } - } - return nil -} - func (b *basicDB) DropTable(ctx context.Context, i interface{}) error { _, err := b.db.NewDropTable().Model(i).IfExists().Exec(ctx) return err diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index f7417cfeb..d9415eff4 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -67,6 +67,7 @@ type DBService struct { db.Basic db.Domain db.Emoji + db.HeaderFilter db.Instance db.List db.Marker @@ -193,6 +194,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { db: db, state: state, }, + HeaderFilter: &headerFilterDB{ + db: db, + state: state, + }, Instance: &instanceDB{ db: db, state: state, diff --git a/internal/db/bundb/headerfilter.go b/internal/db/bundb/headerfilter.go new file mode 100644 index 000000000..087b65c82 --- /dev/null +++ b/internal/db/bundb/headerfilter.go @@ -0,0 +1,207 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( + "context" + "net/http" + "time" + "unsafe" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/uptrace/bun" +) + +type headerFilterDB struct { + db *DB + state *state.State +} + +func (h *headerFilterDB) AllowHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) { + return h.state.Caches.AllowHeaderFilters.RegularMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) { + return h.GetAllowHeaderFilters(ctx) + }) +} + +func (h *headerFilterDB) AllowHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) { + return h.state.Caches.AllowHeaderFilters.InverseMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) { + return h.GetAllowHeaderFilters(ctx) + }) +} + +func (h *headerFilterDB) BlockHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) { + return h.state.Caches.BlockHeaderFilters.RegularMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) { + return h.GetBlockHeaderFilters(ctx) + }) +} + +func (h *headerFilterDB) BlockHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) { + return h.state.Caches.BlockHeaderFilters.InverseMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) { + return h.GetBlockHeaderFilters(ctx) + }) +} + +func (h *headerFilterDB) GetAllowHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) { + filter := new(gtsmodel.HeaderFilterAllow) + if err := h.db.NewSelect(). + Model(filter). + Where("? = ?", bun.Ident("id"), id). + Scan(ctx); err != nil { + return nil, err + } + return fromAllowFilter(filter), nil +} + +func (h *headerFilterDB) GetBlockHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) { + filter := new(gtsmodel.HeaderFilterBlock) + if err := h.db.NewSelect(). + Model(filter). + Where("? = ?", bun.Ident("id"), id). + Scan(ctx); err != nil { + return nil, err + } + return fromBlockFilter(filter), nil +} + +func (h *headerFilterDB) GetAllowHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) { + var filters []*gtsmodel.HeaderFilterAllow + err := h.db.NewSelect(). + Model(&filters). + Scan(ctx, &filters) + return fromAllowFilters(filters), err +} + +func (h *headerFilterDB) GetBlockHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) { + var filters []*gtsmodel.HeaderFilterBlock + err := h.db.NewSelect(). + Model(&filters). + Scan(ctx, &filters) + return fromBlockFilters(filters), err +} + +func (h *headerFilterDB) PutAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error { + if _, err := h.db.NewInsert(). + Model(toAllowFilter(filter)). + Exec(ctx); err != nil { + return err + } + h.state.Caches.AllowHeaderFilters.Clear() + return nil +} + +func (h *headerFilterDB) PutBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error { + if _, err := h.db.NewInsert(). + Model(toBlockFilter(filter)). + Exec(ctx); err != nil { + return err + } + h.state.Caches.BlockHeaderFilters.Clear() + return nil +} + +func (h *headerFilterDB) UpdateAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error { + filter.UpdatedAt = time.Now() + if len(cols) > 0 { + // If we're updating by column, + // ensure "updated_at" is included. + cols = append(cols, "updated_at") + } + if _, err := h.db.NewUpdate(). + Model(toAllowFilter(filter)). + Column(cols...). + Where("? = ?", bun.Ident("id"), filter.ID). + Exec(ctx); err != nil { + return err + } + h.state.Caches.AllowHeaderFilters.Clear() + return nil +} + +func (h *headerFilterDB) UpdateBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error { + filter.UpdatedAt = time.Now() + if len(cols) > 0 { + // If we're updating by column, + // ensure "updated_at" is included. + cols = append(cols, "updated_at") + } + if _, err := h.db.NewUpdate(). + Model(toBlockFilter(filter)). + Column(cols...). + Where("? = ?", bun.Ident("id"), filter.ID). + Exec(ctx); err != nil { + return err + } + h.state.Caches.BlockHeaderFilters.Clear() + return nil +} + +func (h *headerFilterDB) DeleteAllowHeaderFilter(ctx context.Context, id string) error { + if _, err := h.db.NewDelete(). + Table("header_filter_allows"). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx); err != nil { + return err + } + h.state.Caches.AllowHeaderFilters.Clear() + return nil +} + +func (h *headerFilterDB) DeleteBlockHeaderFilter(ctx context.Context, id string) error { + if _, err := h.db.NewDelete(). + Table("header_filter_blocks"). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx); err != nil { + return err + } + h.state.Caches.BlockHeaderFilters.Clear() + return nil +} + +// NOTE: +// all of the below unsafe cast functions +// are only possible because HeaderFilterAllow{}, +// HeaderFilterBlock{}, HeaderFilter{} while +// different types in source, have exactly the +// same size and layout in memory. the unsafe +// cast simply changes the type associated with +// that block of memory. + +func toAllowFilter(filter *gtsmodel.HeaderFilter) *gtsmodel.HeaderFilterAllow { + return (*gtsmodel.HeaderFilterAllow)(unsafe.Pointer(filter)) +} + +func toBlockFilter(filter *gtsmodel.HeaderFilter) *gtsmodel.HeaderFilterBlock { + return (*gtsmodel.HeaderFilterBlock)(unsafe.Pointer(filter)) +} + +func fromAllowFilter(filter *gtsmodel.HeaderFilterAllow) *gtsmodel.HeaderFilter { + return (*gtsmodel.HeaderFilter)(unsafe.Pointer(filter)) +} + +func fromBlockFilter(filter *gtsmodel.HeaderFilterBlock) *gtsmodel.HeaderFilter { + return (*gtsmodel.HeaderFilter)(unsafe.Pointer(filter)) +} + +func fromAllowFilters(filters []*gtsmodel.HeaderFilterAllow) []*gtsmodel.HeaderFilter { + return *(*[]*gtsmodel.HeaderFilter)(unsafe.Pointer(&filters)) +} + +func fromBlockFilters(filters []*gtsmodel.HeaderFilterBlock) []*gtsmodel.HeaderFilter { + return *(*[]*gtsmodel.HeaderFilter)(unsafe.Pointer(&filters)) +} diff --git a/internal/db/bundb/headerfilter_test.go b/internal/db/bundb/headerfilter_test.go new file mode 100644 index 000000000..d7e2b26ee --- /dev/null +++ b/internal/db/bundb/headerfilter_test.go @@ -0,0 +1,125 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package bundb_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type HeaderFilterTestSuite struct { + BunDBStandardTestSuite +} + +func (suite *HeaderFilterTestSuite) TestAllowHeaderFilterGetPutUpdateDelete() { + suite.testHeaderFilterGetPutUpdateDelete( + suite.db.GetAllowHeaderFilter, + suite.db.GetAllowHeaderFilters, + suite.db.PutAllowHeaderFilter, + suite.db.UpdateAllowHeaderFilter, + suite.db.DeleteAllowHeaderFilter, + ) +} + +func (suite *HeaderFilterTestSuite) TestBlockHeaderFilterGetPutUpdateDelete() { + suite.testHeaderFilterGetPutUpdateDelete( + suite.db.GetBlockHeaderFilter, + suite.db.GetBlockHeaderFilters, + suite.db.PutBlockHeaderFilter, + suite.db.UpdateBlockHeaderFilter, + suite.db.DeleteBlockHeaderFilter, + ) +} + +func (suite *HeaderFilterTestSuite) testHeaderFilterGetPutUpdateDelete( + get func(context.Context, string) (*gtsmodel.HeaderFilter, error), + getAll func(context.Context) ([]*gtsmodel.HeaderFilter, error), + put func(context.Context, *gtsmodel.HeaderFilter) error, + update func(context.Context, *gtsmodel.HeaderFilter, ...string) error, + delete func(context.Context, string) error, +) { + t := suite.T() + + // Create new example header filter. + filter := gtsmodel.HeaderFilter{ + ID: "some unique id", + Header: "Http-Header-Key", + Regex: ".*", + AuthorID: "some unique author id", + } + + // Create new cancellable test context. + ctx := context.Background() + ctx, cncl := context.WithCancel(ctx) + defer cncl() + + // Insert the example header filter into db. + if err := put(ctx, &filter); err != nil { + t.Fatalf("error inserting header filter: %v", err) + } + + // Now fetch newly created filter. + check, err := get(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching header filter: %v", err) + } + + // Check all expected fields match. + suite.Equal(filter.ID, check.ID) + suite.Equal(filter.Header, check.Header) + suite.Equal(filter.Regex, check.Regex) + suite.Equal(filter.AuthorID, check.AuthorID) + + // Fetch all header filters. + all, err := getAll(ctx) + if err != nil { + t.Fatalf("error fetching header filters: %v", err) + } + + // Ensure contains example. + suite.Equal(len(all), 1) + suite.Equal(all[0].ID, filter.ID) + + // Update the header filter regex value. + check.Regex = "new regex value" + if err := update(ctx, check); err != nil { + t.Fatalf("error updating header filter: %v", err) + } + + // Ensure 'updated_at' was updated on check model. + suite.True(check.UpdatedAt.After(filter.UpdatedAt)) + + // Now delete the header filter from db. + if err := delete(ctx, filter.ID); err != nil { + t.Fatalf("error deleting header filter: %v", err) + } + + // Ensure we can't refetch it. + _, err = get(ctx, filter.ID) + if err != db.ErrNoEntries { + t.Fatalf("deleted header filter returned unexpected error: %v", err) + } +} + +func TestHeaderFilterTestSuite(t *testing.T) { + suite.Run(t, new(HeaderFilterTestSuite)) +} diff --git a/internal/db/bundb/migrations/20231212144715_add_header_filters.go b/internal/db/bundb/migrations/20231212144715_add_header_filters.go new file mode 100644 index 000000000..2d671bf46 --- /dev/null +++ b/internal/db/bundb/migrations/20231212144715_add_header_filters.go @@ -0,0 +1,54 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package migrations + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +func init() { + up := func(ctx context.Context, db *bun.DB) error { + for _, model := range []any{ + >smodel.HeaderFilterAllow{}, + >smodel.HeaderFilterBlock{}, + } { + _, err := db.NewCreateTable(). + IfNotExists(). + Model(model). + Exec(ctx) + if err != nil { + return err + } + } + + return nil + } + + down := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + return nil + }) + } + + if err := Migrations.Register(up, down); err != nil { + panic(err) + } +} diff --git a/internal/db/db.go b/internal/db/db.go index 2914d9b59..361687e94 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -30,6 +30,7 @@ type DB interface { Basic Domain Emoji + HeaderFilter Instance List Marker diff --git a/internal/db/headerfilter.go b/internal/db/headerfilter.go new file mode 100644 index 000000000..5fe8a5b17 --- /dev/null +++ b/internal/db/headerfilter.go @@ -0,0 +1,73 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package db + +import ( + "context" + "net/http" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type HeaderFilter interface { + // AllowHeaderRegularMatch performs an headerfilter.Filter.RegularMatch() on cached allow header filters. + // (Note: the actual matching code can be found under ./internal/headerfilter/ ). + AllowHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) + + // AllowHeaderInverseMatch performs an headerfilter.Filter.InverseMatch() on cached allow header filters. + // (Note: the actual matching code can be found under ./internal/headerfilter/ ). + AllowHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) + + // BlockHeaderRegularMatch performs an headerfilter.Filter.RegularMatch() on cached block header filters. + // (Note: the actual matching code can be found under ./internal/headerfilter/ ). + BlockHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) + + // BlockHeaderInverseMatch performs an headerfilter.Filter.InverseMatch() on cached block header filters. + // (Note: the actual matching code can be found under ./internal/headerfilter/ ). + BlockHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) + + // GetAllowHeaderFilter fetches the allow header filter with ID from the database. + GetAllowHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) + + // GetBlockHeaderFilter fetches the block header filter with ID from the database. + GetBlockHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) + + // GetAllowHeaderFilters fetches all allow header filters from the database. + GetAllowHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) + + // GetBlockHeaderFilters fetches all block header filters from the database. + GetBlockHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) + + // PutAllowHeaderFilter inserts the given allow header filter into the database. + PutAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error + + // PutBlockHeaderFilter inserts the given block header filter into the database. + PutBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error + + // UpdateAllowHeaderFilter updates the given allow header filter in the database, only updating given columns if provided. + UpdateAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error + + // UpdateBlockHeaderFilter updates the given block header filter in the database, only updating given columns if provided. + UpdateBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error + + // DeleteAllowHeaderFilter deletes the allow header filter with ID from the database. + DeleteAllowHeaderFilter(ctx context.Context, id string) error + + // DeleteBlockHeaderFilter deletes the block header filter with ID from the database. + DeleteBlockHeaderFilter(ctx context.Context, id string) error +} |