diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/api/client/auth/authorize.go | 8 | ||||
| -rw-r--r-- | internal/api/client/auth/authorize_test.go | 10 | ||||
| -rw-r--r-- | internal/api/client/auth/callback.go | 3 | ||||
| -rw-r--r-- | internal/api/client/auth/signin.go | 6 | ||||
| -rw-r--r-- | internal/api/security/tokencheck.go | 23 | ||||
| -rw-r--r-- | internal/cache/user.go | 141 | ||||
| -rw-r--r-- | internal/db/bundb/admin.go | 5 | ||||
| -rw-r--r-- | internal/db/bundb/bundb.go | 13 | ||||
| -rw-r--r-- | internal/db/bundb/user.go | 151 | ||||
| -rw-r--r-- | internal/db/bundb/user_test.go | 73 | ||||
| -rw-r--r-- | internal/db/db.go | 1 | ||||
| -rw-r--r-- | internal/db/user.go | 42 | ||||
| -rw-r--r-- | internal/processing/account/delete.go | 19 | ||||
| -rw-r--r-- | internal/processing/fromclientapi.go | 5 | ||||
| -rw-r--r-- | internal/processing/fromfederator_test.go | 2 | ||||
| -rw-r--r-- | internal/processing/instance.go | 4 | ||||
| -rw-r--r-- | internal/processing/streaming/authorize.go | 4 | ||||
| -rw-r--r-- | internal/processing/user/emailconfirm.go | 4 | ||||
| -rw-r--r-- | internal/typeutils/internaltofrontend_test.go | 8 | ||||
| -rw-r--r-- | internal/visibility/statusvisible.go | 8 | 
20 files changed, 476 insertions, 54 deletions
diff --git a/internal/api/client/auth/authorize.go b/internal/api/client/auth/authorize.go index 83cddd9b5..b345f9b01 100644 --- a/internal/api/client/auth/authorize.go +++ b/internal/api/client/auth/authorize.go @@ -94,8 +94,8 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {  		return  	} -	user := >smodel.User{} -	if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil { +	user, err := m.db.GetUserByID(c.Request.Context(), userID) +	if err != nil {  		m.clearSession(s)  		safe := fmt.Sprintf("user with id %s could not be retrieved", userID)  		var errWithCode gtserror.WithCode @@ -213,8 +213,8 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {  		return  	} -	user := >smodel.User{} -	if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil { +	user, err := m.db.GetUserByID(c.Request.Context(), userID) +	if err != nil {  		m.clearSession(s)  		safe := fmt.Sprintf("user with id %s could not be retrieved", userID)  		var errWithCode gtserror.WithCode diff --git a/internal/api/client/auth/authorize_test.go b/internal/api/client/auth/authorize_test.go index eab893416..fcc4b8caa 100644 --- a/internal/api/client/auth/authorize_test.go +++ b/internal/api/client/auth/authorize_test.go @@ -76,8 +76,11 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {  	doTest := func(testCase authorizeHandlerTestCase) {  		ctx, recorder := suite.newContext(http.MethodGet, auth.OauthAuthorizePath, nil, "") -		user := suite.testUsers["unconfirmed_account"] -		account := suite.testAccounts["unconfirmed_account"] +		user := >smodel.User{} +		account := >smodel.Account{} + +		*user = *suite.testUsers["unconfirmed_account"] +		*account = *suite.testAccounts["unconfirmed_account"]  		testSession := sessions.Default(ctx)  		testSession.Set(sessionUserID, user.ID) @@ -91,8 +94,7 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {  		testCase.description = fmt.Sprintf("%s, %t, %s", user.Email, *user.Disabled, account.SuspendedAt)  		updatingColumns = append(updatingColumns, "updated_at") -		user.UpdatedAt = time.Now() -		err := suite.db.UpdateByPrimaryKey(context.Background(), user, updatingColumns...) +		_, err := suite.db.UpdateUser(context.Background(), user, updatingColumns...)  		suite.NoError(err)  		_, err = suite.db.UpdateAccount(context.Background(), account)  		suite.NoError(err) diff --git a/internal/api/client/auth/callback.go b/internal/api/client/auth/callback.go index 96a73a52f..daee2ae31 100644 --- a/internal/api/client/auth/callback.go +++ b/internal/api/client/auth/callback.go @@ -134,8 +134,7 @@ func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, i  	// see if we already have a user for this email address  	// if so, we don't need to continue + create one -	user := >smodel.User{} -	err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: claims.Email}}, user) +	user, err := m.db.GetUserByEmailAddress(ctx, claims.Email)  	if err == nil {  		return user, nil  	} diff --git a/internal/api/client/auth/signin.go b/internal/api/client/auth/signin.go index 58f3fad7e..06b601b10 100644 --- a/internal/api/client/auth/signin.go +++ b/internal/api/client/auth/signin.go @@ -28,9 +28,7 @@ import (  	"github.com/gin-gonic/gin"  	"github.com/superseriousbusiness/gotosocial/internal/api"  	"github.com/superseriousbusiness/gotosocial/internal/config" -	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror" -	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"golang.org/x/crypto/bcrypt"  ) @@ -119,8 +117,8 @@ func (m *Module) ValidatePassword(ctx context.Context, email string, password st  		return incorrectPassword(err)  	} -	user := >smodel.User{} -	if err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: email}}, user); err != nil { +	user, err := m.db.GetUserByEmailAddress(ctx, email) +	if err != nil {  		err := fmt.Errorf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)  		return incorrectPassword(err)  	} diff --git a/internal/api/security/tokencheck.go b/internal/api/security/tokencheck.go index 3df7ee943..9f2b7f36e 100644 --- a/internal/api/security/tokencheck.go +++ b/internal/api/security/tokencheck.go @@ -52,8 +52,8 @@ func (m *Module) TokenCheck(c *gin.Context) {  		log.Tracef("authenticated user %s with bearer token, scope is %s", userID, ti.GetScope())  		// fetch user for this token -		user := >smodel.User{} -		if err := m.db.GetByID(ctx, userID, user); err != nil { +		user, err := m.db.GetUserByID(ctx, userID) +		if err != nil {  			if err != db.ErrNoEntries {  				log.Errorf("database error looking for user with id %s: %s", userID, err)  				return @@ -80,22 +80,25 @@ func (m *Module) TokenCheck(c *gin.Context) {  		c.Set(oauth.SessionAuthorizedUser, user)  		// fetch account for this token -		acct, err := m.db.GetAccountByID(ctx, user.AccountID) -		if err != nil { -			if err != db.ErrNoEntries { -				log.Errorf("database error looking for account with id %s: %s", user.AccountID, err) +		if user.Account == nil { +			acct, err := m.db.GetAccountByID(ctx, user.AccountID) +			if err != nil { +				if err != db.ErrNoEntries { +					log.Errorf("database error looking for account with id %s: %s", user.AccountID, err) +					return +				} +				log.Warnf("no account found for userID %s", userID)  				return  			} -			log.Warnf("no account found for userID %s", userID) -			return +			user.Account = acct  		} -		if !acct.SuspendedAt.IsZero() { +		if !user.Account.SuspendedAt.IsZero() {  			log.Warnf("authenticated user %s's account (accountId=%s) has been suspended", userID, user.AccountID)  			return  		} -		c.Set(oauth.SessionAuthorizedAccount, acct) +		c.Set(oauth.SessionAuthorizedAccount, user.Account)  	}  	// check for application token diff --git a/internal/cache/user.go b/internal/cache/user.go new file mode 100644 index 000000000..23bf0b7e9 --- /dev/null +++ b/internal/cache/user.go @@ -0,0 +1,141 @@ +/* +   GoToSocial +   Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + +   This program is free software: you can redistribute it and/or modify +   it under the terms of the GNU Affero General Public License as published by +   the Free Software Foundation, either version 3 of the License, or +   (at your option) any later version. + +   This program is distributed in the hope that it will be useful, +   but WITHOUT ANY WARRANTY; without even the implied warranty of +   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +   GNU Affero General Public License for more details. + +   You should have received a copy of the GNU Affero General Public License +   along with this program.  If not, see <http://www.gnu.org/licenses/>. +*/ + +package cache + +import ( +	"time" + +	"codeberg.org/gruf/go-cache/v2" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// UserCache is a cache wrapper to provide lookups for gtsmodel.User +type UserCache struct { +	cache cache.LookupCache[string, string, *gtsmodel.User] +} + +// NewUserCache returns a new instantiated UserCache object +func NewUserCache() *UserCache { +	c := &UserCache{} +	c.cache = cache.NewLookup(cache.LookupCfg[string, string, *gtsmodel.User]{ +		RegisterLookups: func(lm *cache.LookupMap[string, string]) { +			lm.RegisterLookup("accountid") +			lm.RegisterLookup("email") +			lm.RegisterLookup("unconfirmedemail") +			lm.RegisterLookup("confirmationtoken") +		}, + +		AddLookups: func(lm *cache.LookupMap[string, string], user *gtsmodel.User) { +			lm.Set("accountid", user.AccountID, user.ID) +			if email := user.Email; email != "" { +				lm.Set("email", email, user.ID) +			} +			if unconfirmedEmail := user.UnconfirmedEmail; unconfirmedEmail != "" { +				lm.Set("unconfirmedemail", unconfirmedEmail, user.ID) +			} +			if confirmationToken := user.ConfirmationToken; confirmationToken != "" { +				lm.Set("confirmationtoken", confirmationToken, user.ID) +			} +		}, + +		DeleteLookups: func(lm *cache.LookupMap[string, string], user *gtsmodel.User) { +			lm.Delete("accountid", user.AccountID) +			if email := user.Email; email != "" { +				lm.Delete("email", email) +			} +			if unconfirmedEmail := user.UnconfirmedEmail; unconfirmedEmail != "" { +				lm.Delete("unconfirmedemail", unconfirmedEmail) +			} +			if confirmationToken := user.ConfirmationToken; confirmationToken != "" { +				lm.Delete("confirmationtoken", confirmationToken) +			} +		}, +	}) +	c.cache.SetTTL(time.Minute*5, false) +	c.cache.Start(time.Second * 10) +	return c +} + +// GetByID attempts to fetch a user from the cache by its ID, you will receive a copy for thread-safety +func (c *UserCache) GetByID(id string) (*gtsmodel.User, bool) { +	return c.cache.Get(id) +} + +// GetByAccountID attempts to fetch a user from the cache by its account ID, you will receive a copy for thread-safety +func (c *UserCache) GetByAccountID(accountID string) (*gtsmodel.User, bool) { +	return c.cache.GetBy("accountid", accountID) +} + +// GetByEmail attempts to fetch a user from the cache by its email address, you will receive a copy for thread-safety +func (c *UserCache) GetByEmail(email string) (*gtsmodel.User, bool) { +	return c.cache.GetBy("email", email) +} + +// GetByUnconfirmedEmail attempts to fetch a user from the cache by its confirmation token, you will receive a copy for thread-safety +func (c *UserCache) GetByConfirmationToken(token string) (*gtsmodel.User, bool) { +	return c.cache.GetBy("confirmationtoken", token) +} + +// Put places a user in the cache, ensuring that the object place is a copy for thread-safety +func (c *UserCache) Put(user *gtsmodel.User) { +	if user == nil || user.ID == "" { +		panic("invalid user") +	} +	c.cache.Set(user.ID, copyUser(user)) +} + +// Invalidate invalidates one user from the cache using the ID of the user as key. +func (c *UserCache) Invalidate(userID string) { +	c.cache.Invalidate(userID) +} + +func copyUser(user *gtsmodel.User) *gtsmodel.User { +	return >smodel.User{ +		ID:                     user.ID, +		CreatedAt:              user.CreatedAt, +		UpdatedAt:              user.UpdatedAt, +		Email:                  user.Email, +		AccountID:              user.AccountID, +		Account:                nil, +		EncryptedPassword:      user.EncryptedPassword, +		SignUpIP:               user.SignUpIP, +		CurrentSignInAt:        user.CurrentSignInAt, +		CurrentSignInIP:        user.CurrentSignInIP, +		LastSignInAt:           user.LastSignInAt, +		LastSignInIP:           user.LastSignInIP, +		SignInCount:            user.SignInCount, +		InviteID:               user.InviteID, +		ChosenLanguages:        user.ChosenLanguages, +		FilteredLanguages:      user.FilteredLanguages, +		Locale:                 user.Locale, +		CreatedByApplicationID: user.CreatedByApplicationID, +		CreatedByApplication:   nil, +		LastEmailedAt:          user.LastEmailedAt, +		ConfirmationToken:      user.ConfirmationToken, +		ConfirmationSentAt:     user.ConfirmationSentAt, +		ConfirmedAt:            user.ConfirmedAt, +		UnconfirmedEmail:       user.UnconfirmedEmail, +		Moderator:              copyBoolPtr(user.Moderator), +		Admin:                  copyBoolPtr(user.Admin), +		Disabled:               copyBoolPtr(user.Disabled), +		Approved:               copyBoolPtr(user.Approved), +		ResetPasswordToken:     user.ResetPasswordToken, +		ResetPasswordSentAt:    user.ResetPasswordSentAt, +	} +} diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go index f66ed0294..9fa78eca0 100644 --- a/internal/db/bundb/admin.go +++ b/internal/db/bundb/admin.go @@ -30,6 +30,7 @@ import (  	"time"  	"github.com/superseriousbusiness/gotosocial/internal/ap" +	"github.com/superseriousbusiness/gotosocial/internal/cache"  	"github.com/superseriousbusiness/gotosocial/internal/config"  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -40,7 +41,8 @@ import (  )  type adminDB struct { -	conn *DBConn +	conn      *DBConn +	userCache *cache.UserCache  }  func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) { @@ -175,6 +177,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,  		Exec(ctx); err != nil {  		return nil, a.conn.ProcessError(err)  	} +	a.userCache.Put(u)  	return u, nil  } diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 1579fae76..70a44d4c1 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -87,6 +87,7 @@ type DBService struct {  	db.Session  	db.Status  	db.Timeline +	db.User  	conn *DBConn  } @@ -181,13 +182,15 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {  	notifCache.SetTTL(time.Minute*5, false)  	notifCache.Start(time.Second * 10) -	// Prepare domain block cache +	// Prepare other caches  	blockCache := cache.NewDomainBlockCache() +	userCache := cache.NewUserCache()  	ps := &DBService{  		Account: accounts,  		Admin: &adminDB{ -			conn: conn, +			conn:      conn, +			userCache: userCache,  		},  		Basic: &basicDB{  			conn: conn, @@ -219,7 +222,11 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {  		},  		Status:   status,  		Timeline: timeline, -		conn:     conn, +		User: &userDB{ +			conn:  conn, +			cache: userCache, +		}, +		conn: conn,  	}  	// we can confidently return this useable service now diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go new file mode 100644 index 000000000..46f24c4b2 --- /dev/null +++ b/internal/db/bundb/user.go @@ -0,0 +1,151 @@ +/* +   GoToSocial +   Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + +   This program is free software: you can redistribute it and/or modify +   it under the terms of the GNU Affero General Public License as published by +   the Free Software Foundation, either version 3 of the License, or +   (at your option) any later version. + +   This program is distributed in the hope that it will be useful, +   but WITHOUT ANY WARRANTY; without even the implied warranty of +   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +   GNU Affero General Public License for more details. + +   You should have received a copy of the GNU Affero General Public License +   along with this program.  If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( +	"context" +	"time" + +	"github.com/superseriousbusiness/gotosocial/internal/cache" +	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/uptrace/bun" +) + +type userDB struct { +	conn  *DBConn +	cache *cache.UserCache +} + +func (u *userDB) newUserQ(user *gtsmodel.User) *bun.SelectQuery { +	return u.conn. +		NewSelect(). +		Model(user). +		Relation("Account") +} + +func (u *userDB) getUser(ctx context.Context, cacheGet func() (*gtsmodel.User, bool), dbQuery func(*gtsmodel.User) error) (*gtsmodel.User, db.Error) { +	// Attempt to fetch cached user +	user, cached := cacheGet() + +	if !cached { +		user = >smodel.User{} + +		// Not cached! Perform database query +		err := dbQuery(user) +		if err != nil { +			return nil, u.conn.ProcessError(err) +		} + +		// Place in the cache +		u.cache.Put(user) +	} + +	return user, nil +} + +func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db.Error) { +	return u.getUser( +		ctx, +		func() (*gtsmodel.User, bool) { +			return u.cache.GetByID(id) +		}, +		func(user *gtsmodel.User) error { +			return u.newUserQ(user).Where("user.id = ?", id).Scan(ctx) +		}, +	) +} + +func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, db.Error) { +	return u.getUser( +		ctx, +		func() (*gtsmodel.User, bool) { +			return u.cache.GetByAccountID(accountID) +		}, +		func(user *gtsmodel.User) error { +			return u.newUserQ(user).Where("user.account_id = ?", accountID).Scan(ctx) +		}, +	) +} + +func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, db.Error) { +	return u.getUser( +		ctx, +		func() (*gtsmodel.User, bool) { +			return u.cache.GetByEmail(emailAddress) +		}, +		func(user *gtsmodel.User) error { +			return u.newUserQ(user).Where("user.email = ?", emailAddress).Scan(ctx) +		}, +	) +} + +func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, db.Error) { +	return u.getUser( +		ctx, +		func() (*gtsmodel.User, bool) { +			return u.cache.GetByConfirmationToken(confirmationToken) +		}, +		func(user *gtsmodel.User) error { +			return u.newUserQ(user).Where("user.confirmation_token = ?", confirmationToken).Scan(ctx) +		}, +	) +} + +func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) (*gtsmodel.User, db.Error) { +	if _, err := u.conn. +		NewInsert(). +		Model(user). +		Exec(ctx); err != nil { +		return nil, u.conn.ProcessError(err) +	} + +	u.cache.Put(user) +	return user, nil +} + +func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) (*gtsmodel.User, db.Error) { +	// Update the user's last-updated +	user.UpdatedAt = time.Now() + +	if _, err := u.conn. +		NewUpdate(). +		Model(user). +		WherePK(). +		Column(columns...). +		Exec(ctx); err != nil { +		return nil, u.conn.ProcessError(err) +	} + +	u.cache.Invalidate(user.ID) +	return user, nil +} + +func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error { +	if _, err := u.conn. +		NewDelete(). +		Model(>smodel.User{ID: userID}). +		WherePK(). +		Exec(ctx); err != nil { +		return u.conn.ProcessError(err) +	} + +	u.cache.Invalidate(userID) +	return nil +} diff --git a/internal/db/bundb/user_test.go b/internal/db/bundb/user_test.go new file mode 100644 index 000000000..6ad59fc8e --- /dev/null +++ b/internal/db/bundb/user_test.go @@ -0,0 +1,73 @@ +/* +   GoToSocial +   Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + +   This program is free software: you can redistribute it and/or modify +   it under the terms of the GNU Affero General Public License as published by +   the Free Software Foundation, either version 3 of the License, or +   (at your option) any later version. + +   This program is distributed in the hope that it will be useful, +   but WITHOUT ANY WARRANTY; without even the implied warranty of +   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +   GNU Affero General Public License for more details. + +   You should have received a copy of the GNU Affero General Public License +   along with this program.  If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb_test + +import ( +	"context" +	"testing" + +	"github.com/stretchr/testify/suite" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type UserTestSuite struct { +	BunDBStandardTestSuite +} + +func (suite *UserTestSuite) TestGetUser() { +	user, err := suite.db.GetUserByID(context.Background(), suite.testUsers["local_account_1"].ID) +	suite.NoError(err) +	suite.NotNil(user) +} + +func (suite *UserTestSuite) TestGetUserByEmailAddress() { +	user, err := suite.db.GetUserByEmailAddress(context.Background(), suite.testUsers["local_account_1"].Email) +	suite.NoError(err) +	suite.NotNil(user) +} + +func (suite *UserTestSuite) TestGetUserByAccountID() { +	user, err := suite.db.GetUserByAccountID(context.Background(), suite.testAccounts["local_account_1"].ID) +	suite.NoError(err) +	suite.NotNil(user) +} + +func (suite *UserTestSuite) TestUpdateUserSelectedColumns() { +	testUser := suite.testUsers["local_account_1"] +	user := >smodel.User{ +		ID:     testUser.ID, +		Email:  "whatever", +		Locale: "es", +	} + +	user, err := suite.db.UpdateUser(context.Background(), user, "email", "locale") +	suite.NoError(err) +	suite.NotNil(user) + +	dbUser, err := suite.db.GetUserByID(context.Background(), testUser.ID) +	suite.NoError(err) +	suite.NotNil(dbUser) +	suite.Equal("whatever", dbUser.Email) +	suite.Equal("es", dbUser.Locale) +	suite.Equal(testUser.AccountID, dbUser.AccountID) +} + +func TestUserTestSuite(t *testing.T) { +	suite.Run(t, new(UserTestSuite)) +} diff --git a/internal/db/db.go b/internal/db/db.go index 0c1f2602a..52a76ecdb 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -44,6 +44,7 @@ type DB interface {  	Session  	Status  	Timeline +	User  	/*  		USEFUL CONVERSION FUNCTIONS diff --git a/internal/db/user.go b/internal/db/user.go new file mode 100644 index 000000000..a4d48db56 --- /dev/null +++ b/internal/db/user.go @@ -0,0 +1,42 @@ +/* +   GoToSocial +   Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + +   This program is free software: you can redistribute it and/or modify +   it under the terms of the GNU Affero General Public License as published by +   the Free Software Foundation, either version 3 of the License, or +   (at your option) any later version. + +   This program is distributed in the hope that it will be useful, +   but WITHOUT ANY WARRANTY; without even the implied warranty of +   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +   GNU Affero General Public License for more details. + +   You should have received a copy of the GNU Affero General Public License +   along with this program.  If not, see <http://www.gnu.org/licenses/>. +*/ + +package db + +import ( +	"context" + +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// User contains functions related to user getting/setting/creation. +type User interface { +	// GetUserByID returns one user with the given ID, or an error if something goes wrong. +	GetUserByID(ctx context.Context, id string) (*gtsmodel.User, Error) +	// GetUserByAccountID returns one user by its account ID, or an error if something goes wrong. +	GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, Error) +	// GetUserByID returns one user with the given email address, or an error if something goes wrong. +	GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, Error) +	// GetUserByConfirmationToken returns one user by its confirmation token, or an error if something goes wrong. +	GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, Error) +	// UpdateUser updates one user by its primary key. If columns is set, only given columns +	// will be updated. If not set, all columns will be updated. +	UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) (*gtsmodel.User, Error) +	// DeleteUserByID deletes one user by its ID. +	DeleteUserByID(ctx context.Context, userID string) Error +} diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go index 3a5a9c622..3758a4000 100644 --- a/internal/processing/account/delete.go +++ b/internal/processing/account/delete.go @@ -70,13 +70,14 @@ func (p *processor) Delete(ctx context.Context, account *gtsmodel.Account, origi  	// 1. Delete account's application(s), clients, and oauth tokens  	// we only need to do this step for local account since remote ones won't have any tokens or applications on our server +	var user *gtsmodel.User  	if account.Domain == "" {  		// see if we can get a user for this account -		u := >smodel.User{} -		if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, u); err == nil { +		var err error +		if user, err = p.db.GetUserByAccountID(ctx, account.ID); err == nil {  			// we got one! select all tokens with the user's ID  			tokens := []*gtsmodel.Token{} -			if err := p.db.GetWhere(ctx, []db.Where{{Key: "user_id", Value: u.ID}}, &tokens); err == nil { +			if err := p.db.GetWhere(ctx, []db.Where{{Key: "user_id", Value: user.ID}}, &tokens); err == nil {  				// we have some tokens to delete  				for _, t := range tokens {  					// delete client(s) associated with this token @@ -240,9 +241,11 @@ selectStatusesLoop:  	// TODO  	// 16. Delete account's user -	l.Debug("deleting account user") -	if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, >smodel.User{}); err != nil { -		return gtserror.NewErrorInternalError(err) +	if user != nil { +		l.Debug("deleting account user") +		if err := p.db.DeleteUserByID(ctx, user.ID); err != nil { +			return gtserror.NewErrorInternalError(err) +		}  	}  	// 17. Delete account's timeline @@ -288,8 +291,8 @@ func (p *processor) DeleteLocal(ctx context.Context, account *gtsmodel.Account,  	if form.DeleteOriginID == account.ID {  		// the account owner themself has requested deletion via the API, get their user from the db -		user := >smodel.User{} -		if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, user); err != nil { +		user, err := p.db.GetUserByAccountID(ctx, account.ID) +		if err != nil {  			return gtserror.NewErrorInternalError(err)  		} diff --git a/internal/processing/fromclientapi.go b/internal/processing/fromclientapi.go index d7c9c5d82..a688e3732 100644 --- a/internal/processing/fromclientapi.go +++ b/internal/processing/fromclientapi.go @@ -29,7 +29,6 @@ import (  	"github.com/superseriousbusiness/activity/pub"  	"github.com/superseriousbusiness/activity/streams"  	"github.com/superseriousbusiness/gotosocial/internal/ap" -	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/messages" @@ -138,8 +137,8 @@ func (p *processor) processCreateAccountFromClientAPI(ctx context.Context, clien  	}  	// get the user this account belongs to -	user := >smodel.User{} -	if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, user); err != nil { +	user, err := p.db.GetUserByAccountID(ctx, account.ID) +	if err != nil {  		return err  	} diff --git a/internal/processing/fromfederator_test.go b/internal/processing/fromfederator_test.go index 9337482c4..22d0ba9f4 100644 --- a/internal/processing/fromfederator_test.go +++ b/internal/processing/fromfederator_test.go @@ -370,7 +370,7 @@ func (suite *FromFederatorTestSuite) TestProcessAccountDelete() {  	// no statuses from foss satan should be left in the database  	if !testrig.WaitFor(func() bool {  		s, err := suite.db.GetAccountStatuses(ctx, deletedAccount.ID, 0, false, false, "", "", false, false, false) -		return  s == nil && err == db.ErrNoEntries +		return s == nil && err == db.ErrNoEntries  	}) {  		suite.FailNow("timeout waiting for statuses to be deleted")  	} diff --git a/internal/processing/instance.go b/internal/processing/instance.go index b7418659a..32a4de6f0 100644 --- a/internal/processing/instance.go +++ b/internal/processing/instance.go @@ -142,8 +142,8 @@ func (p *processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe  			return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("account with username %s not retrievable", *form.ContactUsername))  		}  		// make sure it has a user associated with it -		contactUser := >smodel.User{} -		if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: contactAccount.ID}}, contactUser); err != nil { +		contactUser, err := p.db.GetUserByAccountID(ctx, contactAccount.ID) +		if err != nil {  			return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("user for account with username %s not retrievable", *form.ContactUsername))  		}  		// suspended accounts cannot be contact accounts diff --git a/internal/processing/streaming/authorize.go b/internal/processing/streaming/authorize.go index 70e4741e1..cb152b676 100644 --- a/internal/processing/streaming/authorize.go +++ b/internal/processing/streaming/authorize.go @@ -40,8 +40,8 @@ func (p *processor) AuthorizeStreamingRequest(ctx context.Context, accessToken s  		return nil, gtserror.NewErrorUnauthorized(err)  	} -	user := >smodel.User{} -	if err := p.db.GetByID(ctx, uid, user); err != nil { +	user, err := p.db.GetUserByID(ctx, uid) +	if err != nil {  		if err == db.ErrNoEntries {  			err := fmt.Errorf("no user found for validated uid %s", uid)  			return nil, gtserror.NewErrorUnauthorized(err) diff --git a/internal/processing/user/emailconfirm.go b/internal/processing/user/emailconfirm.go index 6bffce7d9..5a68383b8 100644 --- a/internal/processing/user/emailconfirm.go +++ b/internal/processing/user/emailconfirm.go @@ -89,8 +89,8 @@ func (p *processor) ConfirmEmail(ctx context.Context, token string) (*gtsmodel.U  		return nil, gtserror.NewErrorNotFound(errors.New("no token provided"))  	} -	user := >smodel.User{} -	if err := p.db.GetWhere(ctx, []db.Where{{Key: "confirmation_token", Value: token}}, user); err != nil { +	user, err := p.db.GetUserByConfirmationToken(ctx, token) +	if err != nil {  		if err == db.ErrNoEntries {  			return nil, gtserror.NewErrorNotFound(err)  		} diff --git a/internal/typeutils/internaltofrontend_test.go b/internal/typeutils/internaltofrontend_test.go index 6028344b4..a13e5255c 100644 --- a/internal/typeutils/internaltofrontend_test.go +++ b/internal/typeutils/internaltofrontend_test.go @@ -46,9 +46,9 @@ func (suite *InternalToFrontendTestSuite) TestAccountToFrontend() {  func (suite *InternalToFrontendTestSuite) TestAccountToFrontendWithEmojiStruct() {  	testAccount := suite.testAccounts["local_account_1"] // take zork for this test  	testEmoji := suite.testEmojis["rainbow"] -	 +  	testAccount.Emojis = []*gtsmodel.Emoji{testEmoji} -	 +  	apiAccount, err := suite.typeconverter.AccountToAPIAccountPublic(context.Background(), testAccount)  	suite.NoError(err)  	suite.NotNil(apiAccount) @@ -61,9 +61,9 @@ func (suite *InternalToFrontendTestSuite) TestAccountToFrontendWithEmojiStruct()  func (suite *InternalToFrontendTestSuite) TestAccountToFrontendWithEmojiIDs() {  	testAccount := suite.testAccounts["local_account_1"] // take zork for this test  	testEmoji := suite.testEmojis["rainbow"] -	 +  	testAccount.EmojiIDs = []string{testEmoji.ID} -	 +  	apiAccount, err := suite.typeconverter.AccountToAPIAccountPublic(context.Background(), testAccount)  	suite.NoError(err)  	suite.NotNil(apiAccount) diff --git a/internal/visibility/statusvisible.go b/internal/visibility/statusvisible.go index 15d8544ad..c62ebb0af 100644 --- a/internal/visibility/statusvisible.go +++ b/internal/visibility/statusvisible.go @@ -68,8 +68,8 @@ func (f *filter) StatusVisible(ctx context.Context, targetStatus *gtsmodel.Statu  	// if the target user doesn't exist (anymore) then the status also shouldn't be visible  	// note: we only do this for local users  	if targetAccount.Domain == "" { -		targetUser := >smodel.User{} -		if err := f.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: targetAccount.ID}}, targetUser); err != nil { +		targetUser, err := f.db.GetUserByAccountID(ctx, targetAccount.ID) +		if err != nil {  			l.Debug("target user could not be selected")  			if err == db.ErrNoEntries {  				return false, nil @@ -98,8 +98,8 @@ func (f *filter) StatusVisible(ctx context.Context, targetStatus *gtsmodel.Statu  	// if the requesting user doesn't exist (anymore) then the status also shouldn't be visible  	// note: we only do this for local users  	if requestingAccount.Domain == "" { -		requestingUser := >smodel.User{} -		if err := f.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: requestingAccount.ID}}, requestingUser); err != nil { +		requestingUser, err := f.db.GetUserByAccountID(ctx, requestingAccount.ID) +		if err != nil {  			// if the requesting account is local but doesn't have a corresponding user in the db this is a problem  			l.Debug("requesting user could not be selected")  			if err == db.ErrNoEntries {  | 
