diff options
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/bundb/account.go | 24 | ||||
-rw-r--r-- | internal/db/bundb/domain.go | 30 |
2 files changed, 26 insertions, 28 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index ccf7aaa46..56d46a232 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -27,9 +27,11 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" "github.com/uptrace/bun/dialect" ) @@ -82,6 +84,15 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel. } func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) { + if domain != "" { + // Normalize the domain as punycode + var err error + domain, err = util.Punify(domain) + if err != nil { + return nil, err + } + } + return a.getAccount( ctx, "Username.Domain", @@ -220,7 +231,10 @@ func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func( } func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Account) error { - var err error + var ( + err error + errs = make(gtserror.MultiError, 0, 3) + ) if account.AvatarMediaAttachment == nil && account.AvatarMediaAttachmentID != "" { // Account avatar attachment is not set, fetch from database. @@ -229,7 +243,7 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou account.AvatarMediaAttachmentID, ) if err != nil { - return fmt.Errorf("error populating account avatar: %w", err) + errs.Append(fmt.Errorf("error populating account avatar: %w", err)) } } @@ -240,7 +254,7 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou account.HeaderMediaAttachmentID, ) if err != nil { - return fmt.Errorf("error populating account header: %w", err) + errs.Append(fmt.Errorf("error populating account header: %w", err)) } } @@ -251,11 +265,11 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou account.EmojiIDs, ) if err != nil { - return fmt.Errorf("error populating account emojis: %w", err) + errs.Append(fmt.Errorf("error populating account emojis: %w", err)) } } - return nil + return errs.Combine() } func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) db.Error { diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index b9d03e98f..5c92645de 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -20,14 +20,13 @@ package bundb import ( "context" "net/url" - "strings" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/uptrace/bun" - "golang.org/x/net/idna" ) type domainDB struct { @@ -35,22 +34,10 @@ type domainDB struct { state *state.State } -// normalizeDomain converts the given domain to lowercase -// then to punycode (for international domain names). -// -// Returns the resulting domain or an error if the -// punycode conversion fails. -func normalizeDomain(domain string) (out string, err error) { - out = strings.ToLower(domain) - out, err = idna.ToASCII(out) - return out, err -} - func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) db.Error { - var err error - // Normalize the domain as punycode - block.Domain, err = normalizeDomain(block.Domain) + var err error + block.Domain, err = util.Punify(block.Domain) if err != nil { return err } @@ -69,10 +56,8 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.Domain } func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, db.Error) { - var err error - // Normalize the domain as punycode - domain, err = normalizeDomain(domain) + domain, err := util.Punify(domain) if err != nil { return nil, err } @@ -98,9 +83,8 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel } func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error { - var err error - - domain, err = normalizeDomain(domain) + // Normalize the domain as punycode + domain, err := util.Punify(domain) if err != nil { return err } @@ -121,7 +105,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) { // Normalize the domain as punycode - domain, err := normalizeDomain(domain) + domain, err := util.Punify(domain) if err != nil { return false, err } |