summaryrefslogtreecommitdiff
path: root/internal/db/bundb/admin.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/admin.go')
-rw-r--r--internal/db/bundb/admin.go82
1 files changed, 52 insertions, 30 deletions
diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go
index 9fa78eca0..44861a4bb 100644
--- a/internal/db/bundb/admin.go
+++ b/internal/db/bundb/admin.go
@@ -22,7 +22,6 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
- "database/sql"
"fmt"
"net"
"net/mail"
@@ -37,21 +36,26 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/uris"
+ "github.com/uptrace/bun"
"golang.org/x/crypto/bcrypt"
)
+// generate RSA keys of this length
+const rsaKeyBits = 2048
+
type adminDB struct {
- conn *DBConn
- userCache *cache.UserCache
+ conn *DBConn
+ userCache *cache.UserCache
+ accountCache *cache.AccountCache
}
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
q := a.conn.
NewSelect().
- Model(&gtsmodel.Account{}).
- Where("username = ?", username).
- Where("domain = ?", nil)
-
+ TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
+ Column("account.id").
+ Where("? = ?", bun.Ident("account.username"), username).
+ Where("? IS NULL", bun.Ident("account.domain"))
return a.conn.NotExists(ctx, q)
}
@@ -64,29 +68,31 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
// check if the email domain is blocked
- if err := a.conn.
+ emailDomainBlockedQ := a.conn.
NewSelect().
- Model(&gtsmodel.EmailDomainBlock{}).
- Where("domain = ?", domain).
- Scan(ctx); err == nil {
- // fail because we found something
+ TableExpr("? AS ?", bun.Ident("email_domain_blocks"), bun.Ident("email_domain_block")).
+ Column("email_domain_block.id").
+ Where("? = ?", bun.Ident("email_domain_block.domain"), domain)
+ emailDomainBlocked, err := a.conn.Exists(ctx, emailDomainBlockedQ)
+ if err != nil {
+ return false, err
+ }
+ if emailDomainBlocked {
return false, fmt.Errorf("email domain %s is blocked", domain)
- } else if err != sql.ErrNoRows {
- return false, a.conn.ProcessError(err)
}
// check if this email is associated with a user already
q := a.conn.
NewSelect().
- Model(&gtsmodel.User{}).
- Where("email = ?", email).
- WhereOr("unconfirmed_email = ?", email)
-
+ TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")).
+ Column("user.id").
+ Where("? = ?", bun.Ident("user.email"), email).
+ WhereOr("? = ?", bun.Ident("user.unconfirmed_email"), email)
return a.conn.NotExists(ctx, q)
}
func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
- key, err := rsa.GenerateKey(rand.Reader, 2048)
+ key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits)
if err != nil {
log.Errorf("error creating new rsa key: %s", err)
return nil, err
@@ -94,13 +100,20 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
// if something went wrong while creating a user, we might already have an account, so check here first...
acct := &gtsmodel.Account{}
- q := a.conn.NewSelect().
+ if err := a.conn.
+ NewSelect().
Model(acct).
- Where("username = ?", username).
- WhereGroup(" AND ", whereEmptyOrNull("domain"))
+ Where("? = ?", bun.Ident("account.username"), username).
+ WhereGroup(" AND ", whereEmptyOrNull("account.domain")).
+ Scan(ctx); err != nil {
+ err = a.conn.ProcessError(err)
+ if err != db.ErrNoEntries {
+ log.Errorf("error checking for existing account: %s", err)
+ return nil, err
+ }
- if err := q.Scan(ctx); err != nil {
- // we just don't have an account yet so create one before we proceed
+ // if we have db.ErrNoEntries, we just don't have an
+ // account yet so create one before we proceed
accountURIs := uris.GenerateURIsForAccount(username)
accountID, err := id.NewRandomULID()
if err != nil {
@@ -126,14 +139,19 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
FeaturedCollectionURI: accountURIs.CollectionURI,
}
+ // insert the new account!
if _, err = a.conn.
NewInsert().
Model(acct).
Exec(ctx); err != nil {
return nil, a.conn.ProcessError(err)
}
+ a.accountCache.Put(acct)
}
+ // we either created or already had an account by now,
+ // so proceed with creating a user for that account
+
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("error hashing password: %s", err)
@@ -171,6 +189,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
u.Moderator = &moderator
}
+ // insert the user!
if _, err = a.conn.
NewInsert().
Model(u).
@@ -187,9 +206,10 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
q := a.conn.
NewSelect().
- Model(&gtsmodel.Account{}).
- Where("username = ?", username).
- WhereGroup(" AND ", whereEmptyOrNull("domain"))
+ TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
+ Column("account.id").
+ Where("? = ?", bun.Ident("account.username"), username).
+ WhereGroup(" AND ", whereEmptyOrNull("account.domain"))
exists, err := a.conn.Exists(ctx, q)
if err != nil {
@@ -200,7 +220,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
return nil
}
- key, err := rsa.GenerateKey(rand.Reader, 2048)
+ key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits)
if err != nil {
log.Errorf("error creating new rsa key: %s", err)
return err
@@ -237,6 +257,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
return a.conn.ProcessError(err)
}
+ a.accountCache.Put(acct)
log.Infof("instance account %s CREATED with id %s", username, acct.ID)
return nil
}
@@ -248,8 +269,9 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
// check if instance entry already exists
q := a.conn.
NewSelect().
- Model(&gtsmodel.Instance{}).
- Where("domain = ?", host)
+ Column("instance.id").
+ TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")).
+ Where("? = ?", bun.Ident("instance.domain"), host)
exists, err := a.conn.Exists(ctx, q)
if err != nil {