diff options
Diffstat (limited to 'internal/db')
| -rw-r--r-- | internal/db/application.go | 38 | ||||
| -rw-r--r-- | internal/db/bundb/application.go | 97 | ||||
| -rw-r--r-- | internal/db/bundb/application_test.go | 128 | ||||
| -rw-r--r-- | internal/db/bundb/bundb.go | 5 | ||||
| -rw-r--r-- | internal/db/bundb/status.go | 26 | ||||
| -rw-r--r-- | internal/db/bundb/user.go | 167 | ||||
| -rw-r--r-- | internal/db/db.go | 1 | 
7 files changed, 377 insertions, 85 deletions
| diff --git a/internal/db/application.go b/internal/db/application.go new file mode 100644 index 000000000..34a857d3f --- /dev/null +++ b/internal/db/application.go @@ -0,0 +1,38 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package db + +import ( +	"context" + +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type Application interface { +	// GetApplicationByID fetches the application from the database with corresponding ID value. +	GetApplicationByID(ctx context.Context, id string) (*gtsmodel.Application, error) + +	// GetApplicationByClientID fetches the application from the database with corresponding client_id value. +	GetApplicationByClientID(ctx context.Context, clientID string) (*gtsmodel.Application, error) + +	// PutApplication places the new application in the database, erroring on non-unique ID or client_id. +	PutApplication(ctx context.Context, app *gtsmodel.Application) error + +	// DeleteApplicationByClientID deletes the application with corresponding client_id value from the database. +	DeleteApplicationByClientID(ctx context.Context, clientID string) error +} diff --git a/internal/db/bundb/application.go b/internal/db/bundb/application.go new file mode 100644 index 000000000..b53d2c0b0 --- /dev/null +++ b/internal/db/bundb/application.go @@ -0,0 +1,97 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package bundb + +import ( +	"context" + +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/state" +	"github.com/uptrace/bun" +) + +type applicationDB struct { +	db    *WrappedDB +	state *state.State +} + +func (a *applicationDB) GetApplicationByID(ctx context.Context, id string) (*gtsmodel.Application, error) { +	return a.getApplication( +		ctx, +		"ID", +		func(app *gtsmodel.Application) error { +			return a.db.NewSelect().Model(app).Where("? = ?", bun.Ident("id"), id).Scan(ctx) +		}, +		id, +	) +} + +func (a *applicationDB) GetApplicationByClientID(ctx context.Context, clientID string) (*gtsmodel.Application, error) { +	return a.getApplication( +		ctx, +		"ClientID", +		func(app *gtsmodel.Application) error { +			return a.db.NewSelect().Model(app).Where("? = ?", bun.Ident("client_id"), clientID).Scan(ctx) +		}, +		clientID, +	) +} + +func (a *applicationDB) getApplication(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Application) error, keyParts ...any) (*gtsmodel.Application, error) { +	return a.state.Caches.GTS.Application().Load(lookup, func() (*gtsmodel.Application, error) { +		var app gtsmodel.Application + +		// Not cached! Perform database query. +		if err := dbQuery(&app); err != nil { +			return nil, a.db.ProcessError(err) +		} + +		return &app, nil +	}, keyParts...) +} + +func (a *applicationDB) PutApplication(ctx context.Context, app *gtsmodel.Application) error { +	return a.state.Caches.GTS.Application().Store(app, func() error { +		_, err := a.db.NewInsert().Model(app).Exec(ctx) +		return a.db.ProcessError(err) +	}) +} + +func (a *applicationDB) DeleteApplicationByClientID(ctx context.Context, clientID string) error { +	// Attempt to delete application. +	if _, err := a.db.NewDelete(). +		Table("applications"). +		Where("? = ?", bun.Ident("client_id"), clientID). +		Exec(ctx); err != nil { +		return a.db.ProcessError(err) +	} + +	// NOTE about further side effects: +	// +	// We don't need to handle updating any statuses or users +	// (both of which may contain refs to applications), as +	// DeleteApplication__() is only ever called during an +	// account deletion, which handles deletion of the user +	// and all their statuses already. +	// + +	// Clear application from the cache. +	a.state.Caches.GTS.Application().Invalidate("ClientID", clientID) + +	return nil +} diff --git a/internal/db/bundb/application_test.go b/internal/db/bundb/application_test.go new file mode 100644 index 000000000..d2ab05ebd --- /dev/null +++ b/internal/db/bundb/application_test.go @@ -0,0 +1,128 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package bundb_test + +import ( +	"context" +	"errors" +	"reflect" +	"testing" +	"time" + +	"github.com/stretchr/testify/suite" +	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type ApplicationTestSuite struct { +	BunDBStandardTestSuite +} + +func (suite *ApplicationTestSuite) TestGetApplicationBy() { +	t := suite.T() + +	// Create a new context for this test. +	ctx, cncl := context.WithCancel(context.Background()) +	defer cncl() + +	// Sentinel error to mark avoiding a test case. +	sentinelErr := errors.New("sentinel") + +	// isEqual checks if 2 application models are equal. +	isEqual := func(a1, a2 gtsmodel.Application) bool { +		// Clear database-set fields. +		a1.CreatedAt = time.Time{} +		a2.CreatedAt = time.Time{} +		a1.UpdatedAt = time.Time{} +		a2.UpdatedAt = time.Time{} + +		return reflect.DeepEqual(a1, a2) +	} + +	for _, app := range suite.testApplications { +		for lookup, dbfunc := range map[string]func() (*gtsmodel.Application, error){ +			"id": func() (*gtsmodel.Application, error) { +				return suite.db.GetApplicationByID(ctx, app.ID) +			}, + +			"client_id": func() (*gtsmodel.Application, error) { +				return suite.db.GetApplicationByClientID(ctx, app.ClientID) +			}, +		} { +			// Clear database caches. +			suite.state.Caches.Init() + +			t.Logf("checking database lookup %q", lookup) + +			// Perform database function. +			checkApp, err := dbfunc() +			if err != nil { +				if err == sentinelErr { +					continue +				} + +				t.Errorf("error encountered for database lookup %q: %v", lookup, err) +				continue +			} + +			// Check received application data. +			if !isEqual(*checkApp, *app) { +				t.Errorf("application does not contain expected data: %+v", checkApp) +				continue +			} +		} +	} +} + +func (suite *ApplicationTestSuite) TestDeleteApplicationBy() { +	t := suite.T() + +	// Create a new context for this test. +	ctx, cncl := context.WithCancel(context.Background()) +	defer cncl() + +	for _, app := range suite.testApplications { +		for lookup, dbfunc := range map[string]func() error{ +			"client_id": func() error { +				return suite.db.DeleteApplicationByClientID(ctx, app.ClientID) +			}, +		} { +			// Clear database caches. +			suite.state.Caches.Init() + +			t.Logf("checking database lookup %q", lookup) + +			// Perform database function. +			err := dbfunc() +			if err != nil { +				t.Errorf("error encountered for database lookup %q: %v", lookup, err) +				continue +			} + +			// Ensure this application has been deleted and cache cleared. +			if _, err := suite.db.GetApplicationByID(ctx, app.ID); err != db.ErrNoEntries { +				t.Errorf("application does not appear to have been deleted %q: %v", lookup, err) +				continue +			} +		} +	} +} + +func TestApplicationTestSuite(t *testing.T) { +	suite.Run(t, new(ApplicationTestSuite)) +} diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 8387bb8d1..26b31ff28 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -60,6 +60,7 @@ var registerTables = []interface{}{  type DBService struct {  	db.Account  	db.Admin +	db.Application  	db.Basic  	db.Domain  	db.Emoji @@ -168,6 +169,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {  			db:    db,  			state: state,  		}, +		Application: &applicationDB{ +			db:    db, +			state: state, +		},  		Basic: &basicDB{  			db: db,  		}, diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index c6091e2c9..311732299 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -37,19 +37,12 @@ type statusDB struct {  	state *state.State  } -func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { -	return s.db. -		NewSelect(). -		Model(status). -		Relation("CreatedWithApplication") -} -  func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, error) {  	return s.getStatus(  		ctx,  		"ID",  		func(status *gtsmodel.Status) error { -			return s.newStatusQ(status).Where("? = ?", bun.Ident("status.id"), id).Scan(ctx) +			return s.db.NewSelect().Model(status).Where("? = ?", bun.Ident("status.id"), id).Scan(ctx)  		},  		id,  	) @@ -78,7 +71,7 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St  		ctx,  		"URI",  		func(status *gtsmodel.Status) error { -			return s.newStatusQ(status).Where("? = ?", bun.Ident("status.uri"), uri).Scan(ctx) +			return s.db.NewSelect().Model(status).Where("? = ?", bun.Ident("status.uri"), uri).Scan(ctx)  		},  		uri,  	) @@ -89,7 +82,7 @@ func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.St  		ctx,  		"URL",  		func(status *gtsmodel.Status) error { -			return s.newStatusQ(status).Where("? = ?", bun.Ident("status.url"), url).Scan(ctx) +			return s.db.NewSelect().Model(status).Where("? = ?", bun.Ident("status.url"), url).Scan(ctx)  		},  		url,  	) @@ -100,7 +93,7 @@ func (s *statusDB) GetStatusBoost(ctx context.Context, boostOfID string, byAccou  		ctx,  		"BoostOfID.AccountID",  		func(status *gtsmodel.Status) error { -			return s.newStatusQ(status). +			return s.db.NewSelect().Model(status).  				Where("status.boost_of_id = ?", boostOfID).  				Where("status.account_id = ?", byAccountID). @@ -264,6 +257,17 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)  		}  	} +	if status.CreatedWithApplicationID != "" && status.CreatedWithApplication == nil { +		// Populate the status' expected CreatedWithApplication (not always set). +		status.CreatedWithApplication, err = s.state.DB.GetApplicationByID( +			ctx, // these are already barebones +			status.CreatedWithApplicationID, +		) +		if err != nil { +			errs.Appendf("error populating status application: %w", err) +		} +	} +  	return errs.Combine()  } diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go index 4b38d48fa..9df05596e 100644 --- a/internal/db/bundb/user.go +++ b/internal/db/bundb/user.go @@ -24,6 +24,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" +	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/uptrace/bun" @@ -35,107 +36,125 @@ type userDB struct {  }  func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, error) { -	return u.state.Caches.GTS.User().Load("ID", func() (*gtsmodel.User, error) { -		var user gtsmodel.User - -		q := u.db. -			NewSelect(). -			Model(&user). -			Relation("Account"). -			Where("? = ?", bun.Ident("user.id"), id) +	return u.getUser( +		ctx, +		"ID", +		func(user *gtsmodel.User) error { +			return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("id"), id).Scan(ctx) +		}, +		id, +	) +} -		if err := q.Scan(ctx); err != nil { -			return nil, u.db.ProcessError(err) +func (u *userDB) GetUsersByIDs(ctx context.Context, ids []string) ([]*gtsmodel.User, error) { +	var ( +		users = make([]*gtsmodel.User, 0, len(ids)) + +		// Collect errors instead of +		// returning early on any. +		errs gtserror.MultiError +	) + +	for _, id := range ids { +		// Attempt to fetch user from DB. +		user, err := u.GetUserByID(ctx, id) +		if err != nil { +			errs.Appendf("error getting user %s: %w", id, err) +			continue  		} -		return &user, nil -	}, id) +		// Append user to return slice. +		users = append(users, user) +	} + +	return users, errs.Combine()  }  func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, error) { -	return u.state.Caches.GTS.User().Load("AccountID", func() (*gtsmodel.User, error) { -		var user gtsmodel.User - -		q := u.db. -			NewSelect(). -			Model(&user). -			Relation("Account"). -			Where("? = ?", bun.Ident("user.account_id"), accountID) - -		if err := q.Scan(ctx); err != nil { -			return nil, u.db.ProcessError(err) -		} - -		return &user, nil -	}, accountID) +	return u.getUser( +		ctx, +		"AccountID", +		func(user *gtsmodel.User) error { +			return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("account_id"), accountID).Scan(ctx) +		}, +		accountID, +	)  } -func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, error) { -	return u.state.Caches.GTS.User().Load("Email", func() (*gtsmodel.User, error) { -		var user gtsmodel.User - -		q := u.db. -			NewSelect(). -			Model(&user). -			Relation("Account"). -			Where("? = ?", bun.Ident("user.email"), emailAddress) - -		if err := q.Scan(ctx); err != nil { -			return nil, u.db.ProcessError(err) -		} - -		return &user, nil -	}, emailAddress) +func (u *userDB) GetUserByEmailAddress(ctx context.Context, email string) (*gtsmodel.User, error) { +	return u.getUser( +		ctx, +		"Email", +		func(user *gtsmodel.User) error { +			return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("email"), email).Scan(ctx) +		}, +		email, +	)  }  func (u *userDB) GetUserByExternalID(ctx context.Context, id string) (*gtsmodel.User, error) { -	return u.state.Caches.GTS.User().Load("ExternalID", func() (*gtsmodel.User, error) { -		var user gtsmodel.User - -		q := u.db. -			NewSelect(). -			Model(&user). -			Relation("Account"). -			Where("? = ?", bun.Ident("user.external_id"), id) - -		if err := q.Scan(ctx); err != nil { -			return nil, u.db.ProcessError(err) -		} +	return u.getUser( +		ctx, +		"ExternalID", +		func(user *gtsmodel.User) error { +			return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("external_id"), id).Scan(ctx) +		}, +		id, +	) +} -		return &user, nil -	}, id) +func (u *userDB) GetUserByConfirmationToken(ctx context.Context, token string) (*gtsmodel.User, error) { +	return u.getUser( +		ctx, +		"ConfirmationToken", +		func(user *gtsmodel.User) error { +			return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("confirmation_token"), token).Scan(ctx) +		}, +		token, +	)  } -func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, error) { -	return u.state.Caches.GTS.User().Load("ConfirmationToken", func() (*gtsmodel.User, error) { +func (u *userDB) getUser(ctx context.Context, lookup string, dbQuery func(*gtsmodel.User) error, keyParts ...any) (*gtsmodel.User, error) { +	// Fetch user from database cache with loader callback. +	user, err := u.state.Caches.GTS.User().Load(lookup, func() (*gtsmodel.User, error) {  		var user gtsmodel.User -		q := u.db. -			NewSelect(). -			Model(&user). -			Relation("Account"). -			Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken) - -		if err := q.Scan(ctx); err != nil { +		// Not cached! perform database query. +		if err := dbQuery(&user); err != nil {  			return nil, u.db.ProcessError(err)  		}  		return &user, nil -	}, confirmationToken) +	}, keyParts...) +	if err != nil { +		return nil, err +	} + +	// Fetch the related account model for this user. +	user.Account, err = u.state.DB.GetAccountByID( +		gtscontext.SetBarebones(ctx), +		user.AccountID, +	) +	if err != nil { +		return nil, gtserror.Newf("error populating user account: %w", err) +	} + +	return user, nil  }  func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) { -	var users []*gtsmodel.User -	q := u.db. -		NewSelect(). -		Model(&users). -		Relation("Account") +	var userIDs []string -	if err := q.Scan(ctx); err != nil { +	// Scan all user IDs into slice. +	if err := u.db.NewSelect(). +		Table("users"). +		Column("id"). +		Scan(ctx, &userIDs); err != nil {  		return nil, u.db.ProcessError(err)  	} -	return users, nil +	// Transform user IDs into user slice. +	return u.GetUsersByIDs(ctx, userIDs)  }  func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) error { diff --git a/internal/db/db.go b/internal/db/db.go index 7c00050ff..567551c73 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -26,6 +26,7 @@ const (  type DB interface {  	Account  	Admin +	Application  	Basic  	Domain  	Emoji | 
