diff options
Diffstat (limited to 'internal/db/pg.go')
-rw-r--r-- | internal/db/pg.go | 495 |
1 files changed, 453 insertions, 42 deletions
diff --git a/internal/db/pg.go b/internal/db/pg.go index 487af184f..df01132c2 100644 --- a/internal/db/pg.go +++ b/internal/db/pg.go @@ -20,8 +20,12 @@ package db import ( "context" + "crypto/rand" + "crypto/rsa" "errors" "fmt" + "net" + "net/mail" "regexp" "strings" "time" @@ -30,14 +34,17 @@ import ( "github.com/go-pg/pg/extra/pgdebug" "github.com/go-pg/pg/v10" "github.com/go-pg/pg/v10/orm" - "github.com/gotosocial/gotosocial/internal/config" - "github.com/gotosocial/gotosocial/internal/gtsmodel" "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db/model" + "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/superseriousbusiness/gotosocial/pkg/mastotypes" + "golang.org/x/crypto/bcrypt" ) // postgresService satisfies the DB interface type postgresService struct { - config *config.DBConfig + config *config.Config conn *pg.DB log *logrus.Entry cancel context.CancelFunc @@ -46,7 +53,7 @@ type postgresService struct { // newPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface. // Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection. -func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry) (*postgresService, error) { +func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry) (DB, error) { opts, err := derivePGOptions(c) if err != nil { return nil, fmt.Errorf("could not create postgres service: %s", err) @@ -98,18 +105,18 @@ func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry return nil, errors.New("db connection timeout") } - // we can confidently return this useable postgres service now - return &postgresService{ - config: c.DBConfig, - conn: conn, - log: log, - cancel: cancel, - federationDB: newPostgresFederation(conn), - }, nil -} + ps := &postgresService{ + config: c, + conn: conn, + log: log, + cancel: cancel, + } -func (ps *postgresService) Federation() pub.Database { - return ps.federationDB + federatingDB := newFederatingDB(ps, c) + ps.federationDB = federatingDB + + // we can confidently return this useable postgres service now + return ps, nil } /* @@ -168,9 +175,29 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) { } /* - EXTRA FUNCTIONS + FEDERATION FUNCTIONALITY */ +func (ps *postgresService) Federation() pub.Database { + return ps.federationDB +} + +/* + BASIC DB FUNCTIONALITY +*/ + +func (ps *postgresService) CreateTable(i interface{}) error { + return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{ + IfNotExists: true, + }) +} + +func (ps *postgresService) DropTable(i interface{}) error { + return ps.conn.Model(i).DropTable(&orm.DropTableOptions{ + IfExists: true, + }) +} + func (ps *postgresService) Stop(ctx context.Context) error { ps.log.Info("closing db connection") if err := ps.conn.Close(); err != nil { @@ -181,11 +208,15 @@ func (ps *postgresService) Stop(ctx context.Context) error { return nil } +func (ps *postgresService) IsHealthy(ctx context.Context) error { + return ps.conn.Ping(ctx) +} + func (ps *postgresService) CreateSchema(ctx context.Context) error { models := []interface{}{ - (*gtsmodel.Account)(nil), - (*gtsmodel.Status)(nil), - (*gtsmodel.User)(nil), + (*model.Account)(nil), + (*model.Status)(nil), + (*model.User)(nil), } ps.log.Info("creating db schema") @@ -202,32 +233,35 @@ func (ps *postgresService) CreateSchema(ctx context.Context) error { return nil } -func (ps *postgresService) IsHealthy(ctx context.Context) error { - return ps.conn.Ping(ctx) -} - -func (ps *postgresService) CreateTable(i interface{}) error { - return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{ - IfNotExists: true, - }) -} - -func (ps *postgresService) DropTable(i interface{}) error { - return ps.conn.Model(i).DropTable(&orm.DropTableOptions{ - IfExists: true, - }) -} - func (ps *postgresService) GetByID(id string, i interface{}) error { - return ps.conn.Model(i).Where("id = ?", id).Select() + if err := ps.conn.Model(i).Where("id = ?", id).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + + } + return nil } func (ps *postgresService) GetWhere(key string, value interface{}, i interface{}) error { - return ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Select() + if err := ps.conn.Model(i).Where("? = ?", pg.Safe(key), value).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil } func (ps *postgresService) GetAll(i interface{}) error { - return ps.conn.Model(i).Select() + if err := ps.conn.Model(i).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil } func (ps *postgresService) Put(i interface{}) error { @@ -236,16 +270,393 @@ func (ps *postgresService) Put(i interface{}) error { } func (ps *postgresService) UpdateByID(id string, i interface{}) error { - _, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert() + if _, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil +} + +func (ps *postgresService) UpdateOneByID(id string, key string, value interface{}, i interface{}) error { + _, err := ps.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update() return err } func (ps *postgresService) DeleteByID(id string, i interface{}) error { - _, err := ps.conn.Model(i).Where("id = ?", id).Delete() - return err + if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil } func (ps *postgresService) DeleteWhere(key string, value interface{}, i interface{}) error { - _, err := ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Delete() + if _, err := ps.conn.Model(i).Where("? = ?", pg.Safe(key), value).Delete(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil +} + +/* + HANDY SHORTCUTS +*/ + +func (ps *postgresService) GetAccountByUserID(userID string, account *model.Account) error { + user := &model.User{ + ID: userID, + } + if err := ps.conn.Model(user).Where("id = ?", userID).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil +} + +func (ps *postgresService) GetFollowRequestsForAccountID(accountID string, followRequests *[]model.FollowRequest) error { + if err := ps.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil +} + +func (ps *postgresService) GetFollowingByAccountID(accountID string, following *[]model.Follow) error { + if err := ps.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil +} + +func (ps *postgresService) GetFollowersByAccountID(accountID string, followers *[]model.Follow) error { + if err := ps.conn.Model(followers).Where("target_account_id = ?", accountID).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil +} + +func (ps *postgresService) GetStatusesByAccountID(accountID string, statuses *[]model.Status) error { + if err := ps.conn.Model(statuses).Where("account_id = ?", accountID).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil +} + +func (ps *postgresService) GetStatusesByTimeDescending(accountID string, statuses *[]model.Status, limit int) error { + q := ps.conn.Model(statuses).Order("created_at DESC") + if limit != 0 { + q = q.Limit(limit) + } + if accountID != "" { + q = q.Where("account_id = ?", accountID) + } + if err := q.Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil +} + +func (ps *postgresService) GetLastStatusForAccountID(accountID string, status *model.Status) error { + if err := ps.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil + +} + +func (ps *postgresService) IsUsernameAvailable(username string) error { + // if no error we fail because it means we found something + // if error but it's not pg.ErrNoRows then we fail + // if err is pg.ErrNoRows we're good, we found nothing so continue + if err := ps.conn.Model(&model.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil { + return fmt.Errorf("username %s already in use", username) + } else if err != pg.ErrNoRows { + return fmt.Errorf("db error: %s", err) + } + return nil +} + +func (ps *postgresService) IsEmailAvailable(email string) error { + // parse the domain from the email + m, err := mail.ParseAddress(email) + if err != nil { + return fmt.Errorf("error parsing email address %s: %s", email, err) + } + domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @ + + // check if the email domain is blocked + if err := ps.conn.Model(&model.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil { + // fail because we found something + return fmt.Errorf("email domain %s is blocked", domain) + } else if err != pg.ErrNoRows { + // fail because we got an unexpected error + return fmt.Errorf("db error: %s", err) + } + + // check if this email is associated with a user already + if err := ps.conn.Model(&model.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil { + // fail because we found something + return fmt.Errorf("email %s already in use", email) + } else if err != pg.ErrNoRows { + // fail because we got an unexpected error + return fmt.Errorf("db error: %s", err) + } + return nil +} + +func (ps *postgresService) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + ps.log.Errorf("error creating new rsa key: %s", err) + return nil, err + } + + uris := util.GenerateURIs(username, ps.config.Protocol, ps.config.Host) + + a := &model.Account{ + Username: username, + DisplayName: username, + Reason: reason, + URL: uris.UserURL, + PrivateKey: key, + PublicKey: &key.PublicKey, + ActorType: "Person", + URI: uris.UserURI, + InboxURL: uris.InboxURL, + OutboxURL: uris.OutboxURL, + FollowersURL: uris.FollowersURL, + FeaturedCollectionURL: uris.CollectionURL, + } + if _, err = ps.conn.Model(a).Insert(); err != nil { + return nil, err + } + + pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return nil, fmt.Errorf("error hashing password: %s", err) + } + u := &model.User{ + AccountID: a.ID, + EncryptedPassword: string(pw), + SignUpIP: signUpIP, + Locale: locale, + UnconfirmedEmail: email, + CreatedByApplicationID: appID, + Approved: !requireApproval, // if we don't require moderator approval, just pre-approve the user + } + if _, err = ps.conn.Model(u).Insert(); err != nil { + return nil, err + } + + return u, nil +} + +func (ps *postgresService) SetHeaderOrAvatarForAccountID(mediaAttachment *model.MediaAttachment, accountID string) error { + _, err := ps.conn.Model(mediaAttachment).Insert() return err } + +func (ps *postgresService) GetHeaderForAccountID(header *model.MediaAttachment, accountID string) error { + if err := ps.conn.Model(header).Where("account_id = ?", accountID).Where("header = ?", true).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil +} + +func (ps *postgresService) GetAvatarForAccountID(avatar *model.MediaAttachment, accountID string) error { + if err := ps.conn.Model(avatar).Where("account_id = ?", accountID).Where("avatar = ?", true).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil +} + +/* + CONVERSION FUNCTIONS +*/ + +// AccountToMastoSensitive takes an internal account model and transforms it into an account ready to be served through the API. +// The resulting account fits the specifications for the path /api/v1/accounts/verify_credentials, as described here: +// https://docs.joinmastodon.org/methods/accounts/. Note that it's *sensitive* because it's only meant to be exposed to the user +// that the account actually belongs to. +func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotypes.Account, error) { + // we can build this sensitive account easily by first getting the public account.... + mastoAccount, err := ps.AccountToMastoPublic(a) + if err != nil { + return nil, err + } + + // then adding the Source object to it... + + // check pending follow requests aimed at this account + fr := []model.FollowRequest{} + if err := ps.GetFollowRequestsForAccountID(a.ID, &fr); err != nil { + if _, ok := err.(ErrNoEntries); !ok { + return nil, fmt.Errorf("error getting follow requests: %s", err) + } + } + var frc int + if fr != nil { + frc = len(fr) + } + + mastoAccount.Source = &mastotypes.Source{ + Privacy: a.Privacy, + Sensitive: a.Sensitive, + Language: a.Language, + Note: a.Note, + Fields: mastoAccount.Fields, + FollowRequestsCount: frc, + } + + return mastoAccount, nil +} + +func (ps *postgresService) AccountToMastoPublic(a *model.Account) (*mastotypes.Account, error) { + // count followers + followers := []model.Follow{} + if err := ps.GetFollowersByAccountID(a.ID, &followers); err != nil { + if _, ok := err.(ErrNoEntries); !ok { + return nil, fmt.Errorf("error getting followers: %s", err) + } + } + var followersCount int + if followers != nil { + followersCount = len(followers) + } + + // count following + following := []model.Follow{} + if err := ps.GetFollowingByAccountID(a.ID, &following); err != nil { + if _, ok := err.(ErrNoEntries); !ok { + return nil, fmt.Errorf("error getting following: %s", err) + } + } + var followingCount int + if following != nil { + followingCount = len(following) + } + + // count statuses + statuses := []model.Status{} + if err := ps.GetStatusesByAccountID(a.ID, &statuses); err != nil { + if _, ok := err.(ErrNoEntries); !ok { + return nil, fmt.Errorf("error getting last statuses: %s", err) + } + } + var statusesCount int + if statuses != nil { + statusesCount = len(statuses) + } + + // check when the last status was + lastStatus := &model.Status{} + if err := ps.GetLastStatusForAccountID(a.ID, lastStatus); err != nil { + if _, ok := err.(ErrNoEntries); !ok { + return nil, fmt.Errorf("error getting last status: %s", err) + } + } + var lastStatusAt string + if lastStatus != nil { + lastStatusAt = lastStatus.CreatedAt.Format(time.RFC3339) + } + + // build the avatar and header URLs + avi := &model.MediaAttachment{} + if err := ps.GetAvatarForAccountID(avi, a.ID); err != nil { + if _, ok := err.(ErrNoEntries); !ok { + return nil, fmt.Errorf("error getting avatar: %s", err) + } + } + aviURL := avi.File.Path + aviURLStatic := avi.Thumbnail.Path + + header := &model.MediaAttachment{} + if err := ps.GetHeaderForAccountID(avi, a.ID); err != nil { + if _, ok := err.(ErrNoEntries); !ok { + return nil, fmt.Errorf("error getting header: %s", err) + } + } + headerURL := header.File.Path + headerURLStatic := header.Thumbnail.Path + + // get the fields set on this account + fields := []mastotypes.Field{} + for _, f := range a.Fields { + mField := mastotypes.Field{ + Name: f.Name, + Value: f.Value, + } + if !f.VerifiedAt.IsZero() { + mField.VerifiedAt = f.VerifiedAt.Format(time.RFC3339) + } + fields = append(fields, mField) + } + + var acct string + if a.Domain != "" { + // this is a remote user + acct = fmt.Sprintf("%s@%s", a.Username, a.Domain) + } else { + // this is a local user + acct = a.Username + } + + return &mastotypes.Account{ + ID: a.ID, + Username: a.Username, + Acct: acct, + DisplayName: a.DisplayName, + Locked: a.Locked, + Bot: a.Bot, + CreatedAt: a.CreatedAt.Format(time.RFC3339), + Note: a.Note, + URL: a.URL, + Avatar: aviURL, + AvatarStatic: aviURLStatic, + Header: headerURL, + HeaderStatic: headerURLStatic, + FollowersCount: followersCount, + FollowingCount: followingCount, + StatusesCount: statusesCount, + LastStatusAt: lastStatusAt, + Emojis: nil, // TODO: implement this + Fields: fields, + }, nil +} |