diff options
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/bundb/admin.go | 12 | ||||
-rw-r--r-- | internal/db/bundb/bundb.go | 4 | ||||
-rw-r--r-- | internal/db/bundb/bundbnew_test.go | 52 | ||||
-rw-r--r-- | internal/db/bundb/errors.go | 4 | ||||
-rw-r--r-- | internal/db/error.go | 15 |
5 files changed, 78 insertions, 9 deletions
diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go index 37c0db6d3..a92834f9c 100644 --- a/internal/db/bundb/admin.go +++ b/internal/db/bundb/admin.go @@ -94,13 +94,13 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, // if something went wrong while creating a user, we might already have an account, so check here first... acct := >smodel.Account{} - err = a.conn.NewSelect(). + q := a.conn.NewSelect(). Model(acct). Where("username = ?", username). - WhereGroup(" AND ", whereEmptyOrNull("domain")). - Scan(ctx) - if err != nil { - // we just don't have an account yet so create one + WhereGroup(" AND ", whereEmptyOrNull("domain")) + + if err := q.Scan(ctx); err != nil { + // we just don't have an account yet so create one before we proceed accountURIs := uris.GenerateURIsForAccount(username) accountID, err := id.NewRandomULID() if err != nil { @@ -125,6 +125,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, FollowingURI: accountURIs.FollowingURI, FeaturedCollectionURI: accountURIs.CollectionURI, } + if _, err = a.conn. NewInsert(). Model(acct). @@ -158,6 +159,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, if emailVerified { u.ConfirmedAt = time.Now() u.Email = email + u.UnconfirmedEmail = "" } if admin { diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 47fe4fb47..ebdbc4ba2 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -204,7 +204,11 @@ func NewBunDBService(ctx context.Context) (db.DB, error) { } func sqliteConn(ctx context.Context) (*DBConn, error) { + // validate db address has actually been set dbAddress := viper.GetString(config.Keys.DbAddress) + if dbAddress == "" { + return nil, fmt.Errorf("'%s' was not set when attempting to start sqlite", config.Keys.DbAddress) + } // Drop anything fancy from DB address dbAddress = strings.Split(dbAddress, "?")[0] diff --git a/internal/db/bundb/bundbnew_test.go b/internal/db/bundb/bundbnew_test.go new file mode 100644 index 000000000..40a05cb50 --- /dev/null +++ b/internal/db/bundb/bundbnew_test.go @@ -0,0 +1,52 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb_test + +import ( + "context" + "testing" + + "github.com/spf13/viper" + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db/bundb" +) + +type BundbNewTestSuite struct { + BunDBStandardTestSuite +} + +func (suite *BundbNewTestSuite) TestCreateNewDB() { + // create a new db with standard test settings + db, err := bundb.NewBunDBService(context.Background()) + suite.NoError(err) + suite.NotNil(db) +} + +func (suite *BundbNewTestSuite) TestCreateNewSqliteDBNoAddress() { + // create a new db with no address specified + viper.Set(config.Keys.DbAddress, "") + db, err := bundb.NewBunDBService(context.Background()) + suite.EqualError(err, "'db-address' was not set when attempting to start sqlite") + suite.Nil(db) +} + +func TestBundbNewTestSuite(t *testing.T) { + suite.Run(t, new(BundbNewTestSuite)) +} diff --git a/internal/db/bundb/errors.go b/internal/db/bundb/errors.go index 7d0157373..113679226 100644 --- a/internal/db/bundb/errors.go +++ b/internal/db/bundb/errors.go @@ -19,7 +19,7 @@ func processPostgresError(err error) db.Error { // (https://www.postgresql.org/docs/10/errcodes-appendix.html) switch pgErr.Code { case "23505" /* unique_violation */ : - return db.ErrAlreadyExists + return db.NewErrAlreadyExists(pgErr.Message) default: return err } @@ -36,7 +36,7 @@ func processSQLiteError(err error) db.Error { // Handle supplied error code: switch sqliteErr.Code() { case sqlite3.SQLITE_CONSTRAINT_UNIQUE, sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY: - return db.ErrAlreadyExists + return db.NewErrAlreadyExists(err.Error()) default: return err } diff --git a/internal/db/error.go b/internal/db/error.go index 984f96401..9ac0b6aa0 100644 --- a/internal/db/error.go +++ b/internal/db/error.go @@ -28,8 +28,19 @@ var ( ErrNoEntries Error = fmt.Errorf("no entries") // ErrMultipleEntries is returned when a caller expected ONE entry for a query, but multiples were found. ErrMultipleEntries Error = fmt.Errorf("multiple entries") - // ErrAlreadyExists is returned when a caller tries to insert a database entry that already exists in the db. - ErrAlreadyExists Error = fmt.Errorf("already exists") // ErrUnknown denotes an unknown database error. ErrUnknown Error = fmt.Errorf("unknown error") ) + +// ErrAlreadyExists is returned when a caller tries to insert a database entry that already exists in the db. +type ErrAlreadyExists struct { + message string +} + +func (e *ErrAlreadyExists) Error() string { + return e.message +} + +func NewErrAlreadyExists(msg string) error { + return &ErrAlreadyExists{message: msg} +} |