diff options
Diffstat (limited to 'internal/db')
| -rw-r--r-- | internal/db/basic.go | 4 | ||||
| -rw-r--r-- | internal/db/bundb/basic.go | 34 | ||||
| -rw-r--r-- | internal/db/bundb/bundb.go | 5 | ||||
| -rw-r--r-- | internal/db/bundb/headerfilter.go | 207 | ||||
| -rw-r--r-- | internal/db/bundb/headerfilter_test.go | 125 | ||||
| -rw-r--r-- | internal/db/bundb/migrations/20231212144715_add_header_filters.go | 54 | ||||
| -rw-r--r-- | internal/db/db.go | 1 | ||||
| -rw-r--r-- | internal/db/headerfilter.go | 73 | 
8 files changed, 465 insertions, 38 deletions
| diff --git a/internal/db/basic.go b/internal/db/basic.go index 7cd690aef..3a8e2af8d 100644 --- a/internal/db/basic.go +++ b/internal/db/basic.go @@ -25,10 +25,6 @@ type Basic interface {  	// For implementations that don't use tables, this can just return nil.  	CreateTable(ctx context.Context, i interface{}) error -	// CreateAllTables creates *all* tables necessary for the running of GoToSocial. -	// Because it uses the 'if not exists' parameter it is safe to run against a GtS that's already been initialized. -	CreateAllTables(ctx context.Context) error -  	// DropTable drops the table for the given interface.  	// For implementations that don't use tables, this can just return nil.  	DropTable(ctx context.Context, i interface{}) error diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go index e68903efa..488f59ad5 100644 --- a/internal/db/bundb/basic.go +++ b/internal/db/bundb/basic.go @@ -22,7 +22,6 @@ import (  	"errors"  	"github.com/superseriousbusiness/gotosocial/internal/db" -	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/uptrace/bun"  ) @@ -120,39 +119,6 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) error {  	return err  } -func (b *basicDB) CreateAllTables(ctx context.Context) error { -	models := []interface{}{ -		>smodel.Account{}, -		>smodel.Application{}, -		>smodel.Block{}, -		>smodel.DomainBlock{}, -		>smodel.EmailDomainBlock{}, -		>smodel.Follow{}, -		>smodel.FollowRequest{}, -		>smodel.MediaAttachment{}, -		>smodel.Mention{}, -		>smodel.Status{}, -		>smodel.StatusToEmoji{}, -		>smodel.StatusFave{}, -		>smodel.StatusBookmark{}, -		>smodel.ThreadMute{}, -		>smodel.Tag{}, -		>smodel.User{}, -		>smodel.Emoji{}, -		>smodel.Instance{}, -		>smodel.Notification{}, -		>smodel.RouterSession{}, -		>smodel.Token{}, -		>smodel.Client{}, -	} -	for _, i := range models { -		if err := b.CreateTable(ctx, i); err != nil { -			return err -		} -	} -	return nil -} -  func (b *basicDB) DropTable(ctx context.Context, i interface{}) error {  	_, err := b.db.NewDropTable().Model(i).IfExists().Exec(ctx)  	return err diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index f7417cfeb..d9415eff4 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -67,6 +67,7 @@ type DBService struct {  	db.Basic  	db.Domain  	db.Emoji +	db.HeaderFilter  	db.Instance  	db.List  	db.Marker @@ -193,6 +194,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {  			db:    db,  			state: state,  		}, +		HeaderFilter: &headerFilterDB{ +			db:    db, +			state: state, +		},  		Instance: &instanceDB{  			db:    db,  			state: state, diff --git a/internal/db/bundb/headerfilter.go b/internal/db/bundb/headerfilter.go new file mode 100644 index 000000000..087b65c82 --- /dev/null +++ b/internal/db/bundb/headerfilter.go @@ -0,0 +1,207 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( +	"context" +	"net/http" +	"time" +	"unsafe" + +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/state" +	"github.com/uptrace/bun" +) + +type headerFilterDB struct { +	db    *DB +	state *state.State +} + +func (h *headerFilterDB) AllowHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) { +	return h.state.Caches.AllowHeaderFilters.RegularMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) { +		return h.GetAllowHeaderFilters(ctx) +	}) +} + +func (h *headerFilterDB) AllowHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) { +	return h.state.Caches.AllowHeaderFilters.InverseMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) { +		return h.GetAllowHeaderFilters(ctx) +	}) +} + +func (h *headerFilterDB) BlockHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) { +	return h.state.Caches.BlockHeaderFilters.RegularMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) { +		return h.GetBlockHeaderFilters(ctx) +	}) +} + +func (h *headerFilterDB) BlockHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) { +	return h.state.Caches.BlockHeaderFilters.InverseMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) { +		return h.GetBlockHeaderFilters(ctx) +	}) +} + +func (h *headerFilterDB) GetAllowHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) { +	filter := new(gtsmodel.HeaderFilterAllow) +	if err := h.db.NewSelect(). +		Model(filter). +		Where("? = ?", bun.Ident("id"), id). +		Scan(ctx); err != nil { +		return nil, err +	} +	return fromAllowFilter(filter), nil +} + +func (h *headerFilterDB) GetBlockHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) { +	filter := new(gtsmodel.HeaderFilterBlock) +	if err := h.db.NewSelect(). +		Model(filter). +		Where("? = ?", bun.Ident("id"), id). +		Scan(ctx); err != nil { +		return nil, err +	} +	return fromBlockFilter(filter), nil +} + +func (h *headerFilterDB) GetAllowHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) { +	var filters []*gtsmodel.HeaderFilterAllow +	err := h.db.NewSelect(). +		Model(&filters). +		Scan(ctx, &filters) +	return fromAllowFilters(filters), err +} + +func (h *headerFilterDB) GetBlockHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) { +	var filters []*gtsmodel.HeaderFilterBlock +	err := h.db.NewSelect(). +		Model(&filters). +		Scan(ctx, &filters) +	return fromBlockFilters(filters), err +} + +func (h *headerFilterDB) PutAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error { +	if _, err := h.db.NewInsert(). +		Model(toAllowFilter(filter)). +		Exec(ctx); err != nil { +		return err +	} +	h.state.Caches.AllowHeaderFilters.Clear() +	return nil +} + +func (h *headerFilterDB) PutBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error { +	if _, err := h.db.NewInsert(). +		Model(toBlockFilter(filter)). +		Exec(ctx); err != nil { +		return err +	} +	h.state.Caches.BlockHeaderFilters.Clear() +	return nil +} + +func (h *headerFilterDB) UpdateAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error { +	filter.UpdatedAt = time.Now() +	if len(cols) > 0 { +		// If we're updating by column, +		// ensure "updated_at" is included. +		cols = append(cols, "updated_at") +	} +	if _, err := h.db.NewUpdate(). +		Model(toAllowFilter(filter)). +		Column(cols...). +		Where("? = ?", bun.Ident("id"), filter.ID). +		Exec(ctx); err != nil { +		return err +	} +	h.state.Caches.AllowHeaderFilters.Clear() +	return nil +} + +func (h *headerFilterDB) UpdateBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error { +	filter.UpdatedAt = time.Now() +	if len(cols) > 0 { +		// If we're updating by column, +		// ensure "updated_at" is included. +		cols = append(cols, "updated_at") +	} +	if _, err := h.db.NewUpdate(). +		Model(toBlockFilter(filter)). +		Column(cols...). +		Where("? = ?", bun.Ident("id"), filter.ID). +		Exec(ctx); err != nil { +		return err +	} +	h.state.Caches.BlockHeaderFilters.Clear() +	return nil +} + +func (h *headerFilterDB) DeleteAllowHeaderFilter(ctx context.Context, id string) error { +	if _, err := h.db.NewDelete(). +		Table("header_filter_allows"). +		Where("? = ?", bun.Ident("id"), id). +		Exec(ctx); err != nil { +		return err +	} +	h.state.Caches.AllowHeaderFilters.Clear() +	return nil +} + +func (h *headerFilterDB) DeleteBlockHeaderFilter(ctx context.Context, id string) error { +	if _, err := h.db.NewDelete(). +		Table("header_filter_blocks"). +		Where("? = ?", bun.Ident("id"), id). +		Exec(ctx); err != nil { +		return err +	} +	h.state.Caches.BlockHeaderFilters.Clear() +	return nil +} + +// NOTE: +// all of the below unsafe cast functions +// are only possible because HeaderFilterAllow{}, +// HeaderFilterBlock{}, HeaderFilter{} while +// different types in source, have exactly the +// same size and layout in memory. the unsafe +// cast simply changes the type associated with +// that block of memory. + +func toAllowFilter(filter *gtsmodel.HeaderFilter) *gtsmodel.HeaderFilterAllow { +	return (*gtsmodel.HeaderFilterAllow)(unsafe.Pointer(filter)) +} + +func toBlockFilter(filter *gtsmodel.HeaderFilter) *gtsmodel.HeaderFilterBlock { +	return (*gtsmodel.HeaderFilterBlock)(unsafe.Pointer(filter)) +} + +func fromAllowFilter(filter *gtsmodel.HeaderFilterAllow) *gtsmodel.HeaderFilter { +	return (*gtsmodel.HeaderFilter)(unsafe.Pointer(filter)) +} + +func fromBlockFilter(filter *gtsmodel.HeaderFilterBlock) *gtsmodel.HeaderFilter { +	return (*gtsmodel.HeaderFilter)(unsafe.Pointer(filter)) +} + +func fromAllowFilters(filters []*gtsmodel.HeaderFilterAllow) []*gtsmodel.HeaderFilter { +	return *(*[]*gtsmodel.HeaderFilter)(unsafe.Pointer(&filters)) +} + +func fromBlockFilters(filters []*gtsmodel.HeaderFilterBlock) []*gtsmodel.HeaderFilter { +	return *(*[]*gtsmodel.HeaderFilter)(unsafe.Pointer(&filters)) +} diff --git a/internal/db/bundb/headerfilter_test.go b/internal/db/bundb/headerfilter_test.go new file mode 100644 index 000000000..d7e2b26ee --- /dev/null +++ b/internal/db/bundb/headerfilter_test.go @@ -0,0 +1,125 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package bundb_test + +import ( +	"context" +	"testing" + +	"github.com/stretchr/testify/suite" +	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type HeaderFilterTestSuite struct { +	BunDBStandardTestSuite +} + +func (suite *HeaderFilterTestSuite) TestAllowHeaderFilterGetPutUpdateDelete() { +	suite.testHeaderFilterGetPutUpdateDelete( +		suite.db.GetAllowHeaderFilter, +		suite.db.GetAllowHeaderFilters, +		suite.db.PutAllowHeaderFilter, +		suite.db.UpdateAllowHeaderFilter, +		suite.db.DeleteAllowHeaderFilter, +	) +} + +func (suite *HeaderFilterTestSuite) TestBlockHeaderFilterGetPutUpdateDelete() { +	suite.testHeaderFilterGetPutUpdateDelete( +		suite.db.GetBlockHeaderFilter, +		suite.db.GetBlockHeaderFilters, +		suite.db.PutBlockHeaderFilter, +		suite.db.UpdateBlockHeaderFilter, +		suite.db.DeleteBlockHeaderFilter, +	) +} + +func (suite *HeaderFilterTestSuite) testHeaderFilterGetPutUpdateDelete( +	get func(context.Context, string) (*gtsmodel.HeaderFilter, error), +	getAll func(context.Context) ([]*gtsmodel.HeaderFilter, error), +	put func(context.Context, *gtsmodel.HeaderFilter) error, +	update func(context.Context, *gtsmodel.HeaderFilter, ...string) error, +	delete func(context.Context, string) error, +) { +	t := suite.T() + +	// Create new example header filter. +	filter := gtsmodel.HeaderFilter{ +		ID:       "some unique id", +		Header:   "Http-Header-Key", +		Regex:    ".*", +		AuthorID: "some unique author id", +	} + +	// Create new cancellable test context. +	ctx := context.Background() +	ctx, cncl := context.WithCancel(ctx) +	defer cncl() + +	// Insert the example header filter into db. +	if err := put(ctx, &filter); err != nil { +		t.Fatalf("error inserting header filter: %v", err) +	} + +	// Now fetch newly created filter. +	check, err := get(ctx, filter.ID) +	if err != nil { +		t.Fatalf("error fetching header filter: %v", err) +	} + +	// Check all expected fields match. +	suite.Equal(filter.ID, check.ID) +	suite.Equal(filter.Header, check.Header) +	suite.Equal(filter.Regex, check.Regex) +	suite.Equal(filter.AuthorID, check.AuthorID) + +	// Fetch all header filters. +	all, err := getAll(ctx) +	if err != nil { +		t.Fatalf("error fetching header filters: %v", err) +	} + +	// Ensure contains example. +	suite.Equal(len(all), 1) +	suite.Equal(all[0].ID, filter.ID) + +	// Update the header filter regex value. +	check.Regex = "new regex value" +	if err := update(ctx, check); err != nil { +		t.Fatalf("error updating header filter: %v", err) +	} + +	// Ensure 'updated_at' was updated on check model. +	suite.True(check.UpdatedAt.After(filter.UpdatedAt)) + +	// Now delete the header filter from db. +	if err := delete(ctx, filter.ID); err != nil { +		t.Fatalf("error deleting header filter: %v", err) +	} + +	// Ensure we can't refetch it. +	_, err = get(ctx, filter.ID) +	if err != db.ErrNoEntries { +		t.Fatalf("deleted header filter returned unexpected error: %v", err) +	} +} + +func TestHeaderFilterTestSuite(t *testing.T) { +	suite.Run(t, new(HeaderFilterTestSuite)) +} diff --git a/internal/db/bundb/migrations/20231212144715_add_header_filters.go b/internal/db/bundb/migrations/20231212144715_add_header_filters.go new file mode 100644 index 000000000..2d671bf46 --- /dev/null +++ b/internal/db/bundb/migrations/20231212144715_add_header_filters.go @@ -0,0 +1,54 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package migrations + +import ( +	"context" + +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/uptrace/bun" +) + +func init() { +	up := func(ctx context.Context, db *bun.DB) error { +		for _, model := range []any{ +			>smodel.HeaderFilterAllow{}, +			>smodel.HeaderFilterBlock{}, +		} { +			_, err := db.NewCreateTable(). +				IfNotExists(). +				Model(model). +				Exec(ctx) +			if err != nil { +				return err +			} +		} + +		return nil +	} + +	down := func(ctx context.Context, db *bun.DB) error { +		return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +			return nil +		}) +	} + +	if err := Migrations.Register(up, down); err != nil { +		panic(err) +	} +} diff --git a/internal/db/db.go b/internal/db/db.go index 2914d9b59..361687e94 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -30,6 +30,7 @@ type DB interface {  	Basic  	Domain  	Emoji +	HeaderFilter  	Instance  	List  	Marker diff --git a/internal/db/headerfilter.go b/internal/db/headerfilter.go new file mode 100644 index 000000000..5fe8a5b17 --- /dev/null +++ b/internal/db/headerfilter.go @@ -0,0 +1,73 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package db + +import ( +	"context" +	"net/http" + +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type HeaderFilter interface { +	// AllowHeaderRegularMatch performs an headerfilter.Filter.RegularMatch() on cached allow header filters. +	// (Note: the actual matching code can be found under ./internal/headerfilter/ ). +	AllowHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) + +	// AllowHeaderInverseMatch performs an headerfilter.Filter.InverseMatch() on cached allow header filters. +	// (Note: the actual matching code can be found under ./internal/headerfilter/ ). +	AllowHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) + +	// BlockHeaderRegularMatch performs an headerfilter.Filter.RegularMatch() on cached block header filters. +	// (Note: the actual matching code can be found under ./internal/headerfilter/ ). +	BlockHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) + +	// BlockHeaderInverseMatch performs an headerfilter.Filter.InverseMatch() on cached block header filters. +	// (Note: the actual matching code can be found under ./internal/headerfilter/ ). +	BlockHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) + +	// GetAllowHeaderFilter fetches the allow header filter with ID from the database. +	GetAllowHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) + +	// GetBlockHeaderFilter fetches the block header filter with ID from the database. +	GetBlockHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) + +	// GetAllowHeaderFilters fetches all allow header filters from the database. +	GetAllowHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) + +	// GetBlockHeaderFilters fetches all block header filters from the database. +	GetBlockHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) + +	// PutAllowHeaderFilter inserts the given allow header filter into the database. +	PutAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error + +	// PutBlockHeaderFilter inserts the given block header filter into the database. +	PutBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error + +	// UpdateAllowHeaderFilter updates the given allow header filter in the database, only updating given columns if provided. +	UpdateAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error + +	// UpdateBlockHeaderFilter updates the given block header filter in the database, only updating given columns if provided. +	UpdateBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error + +	// DeleteAllowHeaderFilter deletes the allow header filter with ID from the database. +	DeleteAllowHeaderFilter(ctx context.Context, id string) error + +	// DeleteBlockHeaderFilter deletes the block header filter with ID from the database. +	DeleteBlockHeaderFilter(ctx context.Context, id string) error +} | 
