diff options
Diffstat (limited to 'internal/db/bundb/admin.go')
-rw-r--r-- | internal/db/bundb/admin.go | 82 |
1 files changed, 52 insertions, 30 deletions
diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go index 9fa78eca0..44861a4bb 100644 --- a/internal/db/bundb/admin.go +++ b/internal/db/bundb/admin.go @@ -22,7 +22,6 @@ import ( "context" "crypto/rand" "crypto/rsa" - "database/sql" "fmt" "net" "net/mail" @@ -37,21 +36,26 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/uris" + "github.com/uptrace/bun" "golang.org/x/crypto/bcrypt" ) +// generate RSA keys of this length +const rsaKeyBits = 2048 + type adminDB struct { - conn *DBConn - userCache *cache.UserCache + conn *DBConn + userCache *cache.UserCache + accountCache *cache.AccountCache } func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) { q := a.conn. NewSelect(). - Model(>smodel.Account{}). - Where("username = ?", username). - Where("domain = ?", nil) - + TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). + Column("account.id"). + Where("? = ?", bun.Ident("account.username"), username). + Where("? IS NULL", bun.Ident("account.domain")) return a.conn.NotExists(ctx, q) } @@ -64,29 +68,31 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db. domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @ // check if the email domain is blocked - if err := a.conn. + emailDomainBlockedQ := a.conn. NewSelect(). - Model(>smodel.EmailDomainBlock{}). - Where("domain = ?", domain). - Scan(ctx); err == nil { - // fail because we found something + TableExpr("? AS ?", bun.Ident("email_domain_blocks"), bun.Ident("email_domain_block")). + Column("email_domain_block.id"). + Where("? = ?", bun.Ident("email_domain_block.domain"), domain) + emailDomainBlocked, err := a.conn.Exists(ctx, emailDomainBlockedQ) + if err != nil { + return false, err + } + if emailDomainBlocked { return false, fmt.Errorf("email domain %s is blocked", domain) - } else if err != sql.ErrNoRows { - return false, a.conn.ProcessError(err) } // check if this email is associated with a user already q := a.conn. NewSelect(). - Model(>smodel.User{}). - Where("email = ?", email). - WhereOr("unconfirmed_email = ?", email) - + TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")). + Column("user.id"). + Where("? = ?", bun.Ident("user.email"), email). + WhereOr("? = ?", bun.Ident("user.unconfirmed_email"), email) return a.conn.NotExists(ctx, q) } func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) { - key, err := rsa.GenerateKey(rand.Reader, 2048) + key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits) if err != nil { log.Errorf("error creating new rsa key: %s", err) return nil, err @@ -94,13 +100,20 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, // if something went wrong while creating a user, we might already have an account, so check here first... acct := >smodel.Account{} - q := a.conn.NewSelect(). + if err := a.conn. + NewSelect(). Model(acct). - Where("username = ?", username). - WhereGroup(" AND ", whereEmptyOrNull("domain")) + Where("? = ?", bun.Ident("account.username"), username). + WhereGroup(" AND ", whereEmptyOrNull("account.domain")). + Scan(ctx); err != nil { + err = a.conn.ProcessError(err) + if err != db.ErrNoEntries { + log.Errorf("error checking for existing account: %s", err) + return nil, err + } - if err := q.Scan(ctx); err != nil { - // we just don't have an account yet so create one before we proceed + // if we have db.ErrNoEntries, we just don't have an + // account yet so create one before we proceed accountURIs := uris.GenerateURIsForAccount(username) accountID, err := id.NewRandomULID() if err != nil { @@ -126,14 +139,19 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, FeaturedCollectionURI: accountURIs.CollectionURI, } + // insert the new account! if _, err = a.conn. NewInsert(). Model(acct). Exec(ctx); err != nil { return nil, a.conn.ProcessError(err) } + a.accountCache.Put(acct) } + // we either created or already had an account by now, + // so proceed with creating a user for that account + pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return nil, fmt.Errorf("error hashing password: %s", err) @@ -171,6 +189,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, u.Moderator = &moderator } + // insert the user! if _, err = a.conn. NewInsert(). Model(u). @@ -187,9 +206,10 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { q := a.conn. NewSelect(). - Model(>smodel.Account{}). - Where("username = ?", username). - WhereGroup(" AND ", whereEmptyOrNull("domain")) + TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). + Column("account.id"). + Where("? = ?", bun.Ident("account.username"), username). + WhereGroup(" AND ", whereEmptyOrNull("account.domain")) exists, err := a.conn.Exists(ctx, q) if err != nil { @@ -200,7 +220,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { return nil } - key, err := rsa.GenerateKey(rand.Reader, 2048) + key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits) if err != nil { log.Errorf("error creating new rsa key: %s", err) return err @@ -237,6 +257,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { return a.conn.ProcessError(err) } + a.accountCache.Put(acct) log.Infof("instance account %s CREATED with id %s", username, acct.ID) return nil } @@ -248,8 +269,9 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error { // check if instance entry already exists q := a.conn. NewSelect(). - Model(>smodel.Instance{}). - Where("domain = ?", host) + Column("instance.id"). + TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")). + Where("? = ?", bun.Ident("instance.domain"), host) exists, err := a.conn.Exists(ctx, q) if err != nil { |