summaryrefslogtreecommitdiff
path: root/internal/db/pg.go
diff options
context:
space:
mode:
authorLibravatar Tobi Smethurst <31960611+tsmethurst@users.noreply.github.com>2021-04-01 20:46:45 +0200
committerLibravatar GitHub <noreply@github.com>2021-04-01 20:46:45 +0200
commit71a49e2b43218d34f97b2276c43bdeb2df4a53d2 (patch)
tree201c370b16cc5446740660f81f342e8171e9903f /internal/db/pg.go
parentOauth/token (#7) (diff)
downloadgotosocial-71a49e2b43218d34f97b2276c43bdeb2df4a53d2.tar.xz
Api/v1/accounts (#8)
* start work on accounts module * plodding away on the accounts endpoint * groundwork for other account routes * add password validator * validation utils * require account approval flags * comments * comments * go fmt * comments * add distributor stub * rename api to federator * tidy a bit * validate new account requests * rename r router * comments * add domain blocks * add some more shortcuts * add some more shortcuts * check email + username availability * email block checking for signups * chunking away at it * tick off a few more things * some fiddling with tests * add mock package * relocate repo * move mocks around * set app id on new signups * initialize oauth server properly * rename oauth server * proper mocking tests * go fmt ./... * add required fields * change name of func * move validation to account.go * more tests! * add some file utility tools * add mediaconfig * new shortcut * add some more fields * add followrequest model * add notify * update mastotypes * mock out storage interface * start building media interface * start on update credentials * mess about with media a bit more * test image manipulation * media more or less working * account update nearly working * rearranging my package ;) ;) ;) * phew big stuff!!!! * fix type checking * *fiddles* * Add CreateTables func * account registration flow working * tidy * script to step through auth flow * add a lil helper for generating user uris * fiddling with federation a bit * update progress * Tidying and linting
Diffstat (limited to 'internal/db/pg.go')
-rw-r--r--internal/db/pg.go495
1 files changed, 453 insertions, 42 deletions
diff --git a/internal/db/pg.go b/internal/db/pg.go
index 487af184f..df01132c2 100644
--- a/internal/db/pg.go
+++ b/internal/db/pg.go
@@ -20,8 +20,12 @@ package db
import (
"context"
+ "crypto/rand"
+ "crypto/rsa"
"errors"
"fmt"
+ "net"
+ "net/mail"
"regexp"
"strings"
"time"
@@ -30,14 +34,17 @@ import (
"github.com/go-pg/pg/extra/pgdebug"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
- "github.com/gotosocial/gotosocial/internal/config"
- "github.com/gotosocial/gotosocial/internal/gtsmodel"
"github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db/model"
+ "github.com/superseriousbusiness/gotosocial/internal/util"
+ "github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
+ "golang.org/x/crypto/bcrypt"
)
// postgresService satisfies the DB interface
type postgresService struct {
- config *config.DBConfig
+ config *config.Config
conn *pg.DB
log *logrus.Entry
cancel context.CancelFunc
@@ -46,7 +53,7 @@ type postgresService struct {
// newPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
-func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry) (*postgresService, error) {
+func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry) (DB, error) {
opts, err := derivePGOptions(c)
if err != nil {
return nil, fmt.Errorf("could not create postgres service: %s", err)
@@ -98,18 +105,18 @@ func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry
return nil, errors.New("db connection timeout")
}
- // we can confidently return this useable postgres service now
- return &postgresService{
- config: c.DBConfig,
- conn: conn,
- log: log,
- cancel: cancel,
- federationDB: newPostgresFederation(conn),
- }, nil
-}
+ ps := &postgresService{
+ config: c,
+ conn: conn,
+ log: log,
+ cancel: cancel,
+ }
-func (ps *postgresService) Federation() pub.Database {
- return ps.federationDB
+ federatingDB := newFederatingDB(ps, c)
+ ps.federationDB = federatingDB
+
+ // we can confidently return this useable postgres service now
+ return ps, nil
}
/*
@@ -168,9 +175,29 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
}
/*
- EXTRA FUNCTIONS
+ FEDERATION FUNCTIONALITY
*/
+func (ps *postgresService) Federation() pub.Database {
+ return ps.federationDB
+}
+
+/*
+ BASIC DB FUNCTIONALITY
+*/
+
+func (ps *postgresService) CreateTable(i interface{}) error {
+ return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
+ IfNotExists: true,
+ })
+}
+
+func (ps *postgresService) DropTable(i interface{}) error {
+ return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
+ IfExists: true,
+ })
+}
+
func (ps *postgresService) Stop(ctx context.Context) error {
ps.log.Info("closing db connection")
if err := ps.conn.Close(); err != nil {
@@ -181,11 +208,15 @@ func (ps *postgresService) Stop(ctx context.Context) error {
return nil
}
+func (ps *postgresService) IsHealthy(ctx context.Context) error {
+ return ps.conn.Ping(ctx)
+}
+
func (ps *postgresService) CreateSchema(ctx context.Context) error {
models := []interface{}{
- (*gtsmodel.Account)(nil),
- (*gtsmodel.Status)(nil),
- (*gtsmodel.User)(nil),
+ (*model.Account)(nil),
+ (*model.Status)(nil),
+ (*model.User)(nil),
}
ps.log.Info("creating db schema")
@@ -202,32 +233,35 @@ func (ps *postgresService) CreateSchema(ctx context.Context) error {
return nil
}
-func (ps *postgresService) IsHealthy(ctx context.Context) error {
- return ps.conn.Ping(ctx)
-}
-
-func (ps *postgresService) CreateTable(i interface{}) error {
- return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
- IfNotExists: true,
- })
-}
-
-func (ps *postgresService) DropTable(i interface{}) error {
- return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
- IfExists: true,
- })
-}
-
func (ps *postgresService) GetByID(id string, i interface{}) error {
- return ps.conn.Model(i).Where("id = ?", id).Select()
+ if err := ps.conn.Model(i).Where("id = ?", id).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+
+ }
+ return nil
}
func (ps *postgresService) GetWhere(key string, value interface{}, i interface{}) error {
- return ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Select()
+ if err := ps.conn.Model(i).Where("? = ?", pg.Safe(key), value).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
}
func (ps *postgresService) GetAll(i interface{}) error {
- return ps.conn.Model(i).Select()
+ if err := ps.conn.Model(i).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
}
func (ps *postgresService) Put(i interface{}) error {
@@ -236,16 +270,393 @@ func (ps *postgresService) Put(i interface{}) error {
}
func (ps *postgresService) UpdateByID(id string, i interface{}) error {
- _, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert()
+ if _, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
+}
+
+func (ps *postgresService) UpdateOneByID(id string, key string, value interface{}, i interface{}) error {
+ _, err := ps.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
return err
}
func (ps *postgresService) DeleteByID(id string, i interface{}) error {
- _, err := ps.conn.Model(i).Where("id = ?", id).Delete()
- return err
+ if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
}
func (ps *postgresService) DeleteWhere(key string, value interface{}, i interface{}) error {
- _, err := ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Delete()
+ if _, err := ps.conn.Model(i).Where("? = ?", pg.Safe(key), value).Delete(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
+}
+
+/*
+ HANDY SHORTCUTS
+*/
+
+func (ps *postgresService) GetAccountByUserID(userID string, account *model.Account) error {
+ user := &model.User{
+ ID: userID,
+ }
+ if err := ps.conn.Model(user).Where("id = ?", userID).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
+}
+
+func (ps *postgresService) GetFollowRequestsForAccountID(accountID string, followRequests *[]model.FollowRequest) error {
+ if err := ps.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
+}
+
+func (ps *postgresService) GetFollowingByAccountID(accountID string, following *[]model.Follow) error {
+ if err := ps.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
+}
+
+func (ps *postgresService) GetFollowersByAccountID(accountID string, followers *[]model.Follow) error {
+ if err := ps.conn.Model(followers).Where("target_account_id = ?", accountID).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
+}
+
+func (ps *postgresService) GetStatusesByAccountID(accountID string, statuses *[]model.Status) error {
+ if err := ps.conn.Model(statuses).Where("account_id = ?", accountID).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
+}
+
+func (ps *postgresService) GetStatusesByTimeDescending(accountID string, statuses *[]model.Status, limit int) error {
+ q := ps.conn.Model(statuses).Order("created_at DESC")
+ if limit != 0 {
+ q = q.Limit(limit)
+ }
+ if accountID != "" {
+ q = q.Where("account_id = ?", accountID)
+ }
+ if err := q.Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
+}
+
+func (ps *postgresService) GetLastStatusForAccountID(accountID string, status *model.Status) error {
+ if err := ps.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
+
+}
+
+func (ps *postgresService) IsUsernameAvailable(username string) error {
+ // if no error we fail because it means we found something
+ // if error but it's not pg.ErrNoRows then we fail
+ // if err is pg.ErrNoRows we're good, we found nothing so continue
+ if err := ps.conn.Model(&model.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
+ return fmt.Errorf("username %s already in use", username)
+ } else if err != pg.ErrNoRows {
+ return fmt.Errorf("db error: %s", err)
+ }
+ return nil
+}
+
+func (ps *postgresService) IsEmailAvailable(email string) error {
+ // parse the domain from the email
+ m, err := mail.ParseAddress(email)
+ if err != nil {
+ return fmt.Errorf("error parsing email address %s: %s", email, err)
+ }
+ domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
+
+ // check if the email domain is blocked
+ if err := ps.conn.Model(&model.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
+ // fail because we found something
+ return fmt.Errorf("email domain %s is blocked", domain)
+ } else if err != pg.ErrNoRows {
+ // fail because we got an unexpected error
+ return fmt.Errorf("db error: %s", err)
+ }
+
+ // check if this email is associated with a user already
+ if err := ps.conn.Model(&model.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
+ // fail because we found something
+ return fmt.Errorf("email %s already in use", email)
+ } else if err != pg.ErrNoRows {
+ // fail because we got an unexpected error
+ return fmt.Errorf("db error: %s", err)
+ }
+ return nil
+}
+
+func (ps *postgresService) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error) {
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ ps.log.Errorf("error creating new rsa key: %s", err)
+ return nil, err
+ }
+
+ uris := util.GenerateURIs(username, ps.config.Protocol, ps.config.Host)
+
+ a := &model.Account{
+ Username: username,
+ DisplayName: username,
+ Reason: reason,
+ URL: uris.UserURL,
+ PrivateKey: key,
+ PublicKey: &key.PublicKey,
+ ActorType: "Person",
+ URI: uris.UserURI,
+ InboxURL: uris.InboxURL,
+ OutboxURL: uris.OutboxURL,
+ FollowersURL: uris.FollowersURL,
+ FeaturedCollectionURL: uris.CollectionURL,
+ }
+ if _, err = ps.conn.Model(a).Insert(); err != nil {
+ return nil, err
+ }
+
+ pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
+ if err != nil {
+ return nil, fmt.Errorf("error hashing password: %s", err)
+ }
+ u := &model.User{
+ AccountID: a.ID,
+ EncryptedPassword: string(pw),
+ SignUpIP: signUpIP,
+ Locale: locale,
+ UnconfirmedEmail: email,
+ CreatedByApplicationID: appID,
+ Approved: !requireApproval, // if we don't require moderator approval, just pre-approve the user
+ }
+ if _, err = ps.conn.Model(u).Insert(); err != nil {
+ return nil, err
+ }
+
+ return u, nil
+}
+
+func (ps *postgresService) SetHeaderOrAvatarForAccountID(mediaAttachment *model.MediaAttachment, accountID string) error {
+ _, err := ps.conn.Model(mediaAttachment).Insert()
return err
}
+
+func (ps *postgresService) GetHeaderForAccountID(header *model.MediaAttachment, accountID string) error {
+ if err := ps.conn.Model(header).Where("account_id = ?", accountID).Where("header = ?", true).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
+}
+
+func (ps *postgresService) GetAvatarForAccountID(avatar *model.MediaAttachment, accountID string) error {
+ if err := ps.conn.Model(avatar).Where("account_id = ?", accountID).Where("avatar = ?", true).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
+}
+
+/*
+ CONVERSION FUNCTIONS
+*/
+
+// AccountToMastoSensitive takes an internal account model and transforms it into an account ready to be served through the API.
+// The resulting account fits the specifications for the path /api/v1/accounts/verify_credentials, as described here:
+// https://docs.joinmastodon.org/methods/accounts/. Note that it's *sensitive* because it's only meant to be exposed to the user
+// that the account actually belongs to.
+func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotypes.Account, error) {
+ // we can build this sensitive account easily by first getting the public account....
+ mastoAccount, err := ps.AccountToMastoPublic(a)
+ if err != nil {
+ return nil, err
+ }
+
+ // then adding the Source object to it...
+
+ // check pending follow requests aimed at this account
+ fr := []model.FollowRequest{}
+ if err := ps.GetFollowRequestsForAccountID(a.ID, &fr); err != nil {
+ if _, ok := err.(ErrNoEntries); !ok {
+ return nil, fmt.Errorf("error getting follow requests: %s", err)
+ }
+ }
+ var frc int
+ if fr != nil {
+ frc = len(fr)
+ }
+
+ mastoAccount.Source = &mastotypes.Source{
+ Privacy: a.Privacy,
+ Sensitive: a.Sensitive,
+ Language: a.Language,
+ Note: a.Note,
+ Fields: mastoAccount.Fields,
+ FollowRequestsCount: frc,
+ }
+
+ return mastoAccount, nil
+}
+
+func (ps *postgresService) AccountToMastoPublic(a *model.Account) (*mastotypes.Account, error) {
+ // count followers
+ followers := []model.Follow{}
+ if err := ps.GetFollowersByAccountID(a.ID, &followers); err != nil {
+ if _, ok := err.(ErrNoEntries); !ok {
+ return nil, fmt.Errorf("error getting followers: %s", err)
+ }
+ }
+ var followersCount int
+ if followers != nil {
+ followersCount = len(followers)
+ }
+
+ // count following
+ following := []model.Follow{}
+ if err := ps.GetFollowingByAccountID(a.ID, &following); err != nil {
+ if _, ok := err.(ErrNoEntries); !ok {
+ return nil, fmt.Errorf("error getting following: %s", err)
+ }
+ }
+ var followingCount int
+ if following != nil {
+ followingCount = len(following)
+ }
+
+ // count statuses
+ statuses := []model.Status{}
+ if err := ps.GetStatusesByAccountID(a.ID, &statuses); err != nil {
+ if _, ok := err.(ErrNoEntries); !ok {
+ return nil, fmt.Errorf("error getting last statuses: %s", err)
+ }
+ }
+ var statusesCount int
+ if statuses != nil {
+ statusesCount = len(statuses)
+ }
+
+ // check when the last status was
+ lastStatus := &model.Status{}
+ if err := ps.GetLastStatusForAccountID(a.ID, lastStatus); err != nil {
+ if _, ok := err.(ErrNoEntries); !ok {
+ return nil, fmt.Errorf("error getting last status: %s", err)
+ }
+ }
+ var lastStatusAt string
+ if lastStatus != nil {
+ lastStatusAt = lastStatus.CreatedAt.Format(time.RFC3339)
+ }
+
+ // build the avatar and header URLs
+ avi := &model.MediaAttachment{}
+ if err := ps.GetAvatarForAccountID(avi, a.ID); err != nil {
+ if _, ok := err.(ErrNoEntries); !ok {
+ return nil, fmt.Errorf("error getting avatar: %s", err)
+ }
+ }
+ aviURL := avi.File.Path
+ aviURLStatic := avi.Thumbnail.Path
+
+ header := &model.MediaAttachment{}
+ if err := ps.GetHeaderForAccountID(avi, a.ID); err != nil {
+ if _, ok := err.(ErrNoEntries); !ok {
+ return nil, fmt.Errorf("error getting header: %s", err)
+ }
+ }
+ headerURL := header.File.Path
+ headerURLStatic := header.Thumbnail.Path
+
+ // get the fields set on this account
+ fields := []mastotypes.Field{}
+ for _, f := range a.Fields {
+ mField := mastotypes.Field{
+ Name: f.Name,
+ Value: f.Value,
+ }
+ if !f.VerifiedAt.IsZero() {
+ mField.VerifiedAt = f.VerifiedAt.Format(time.RFC3339)
+ }
+ fields = append(fields, mField)
+ }
+
+ var acct string
+ if a.Domain != "" {
+ // this is a remote user
+ acct = fmt.Sprintf("%s@%s", a.Username, a.Domain)
+ } else {
+ // this is a local user
+ acct = a.Username
+ }
+
+ return &mastotypes.Account{
+ ID: a.ID,
+ Username: a.Username,
+ Acct: acct,
+ DisplayName: a.DisplayName,
+ Locked: a.Locked,
+ Bot: a.Bot,
+ CreatedAt: a.CreatedAt.Format(time.RFC3339),
+ Note: a.Note,
+ URL: a.URL,
+ Avatar: aviURL,
+ AvatarStatic: aviURLStatic,
+ Header: headerURL,
+ HeaderStatic: headerURLStatic,
+ FollowersCount: followersCount,
+ FollowingCount: followingCount,
+ StatusesCount: statusesCount,
+ LastStatusAt: lastStatusAt,
+ Emojis: nil, // TODO: implement this
+ Fields: fields,
+ }, nil
+}