diff options
Diffstat (limited to 'internal/db/pg.go')
-rw-r--r-- | internal/db/pg.go | 43 |
1 files changed, 28 insertions, 15 deletions
diff --git a/internal/db/pg.go b/internal/db/pg.go index 24a57d8a5..647285032 100644 --- a/internal/db/pg.go +++ b/internal/db/pg.go @@ -37,7 +37,7 @@ import ( "github.com/google/uuid" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/util" "golang.org/x/crypto/bcrypt" ) @@ -46,14 +46,14 @@ import ( type postgresService struct { config *config.Config conn *pg.DB - log *logrus.Entry + log *logrus.Logger cancel context.CancelFunc federationDB pub.Database } -// newPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface. +// NewPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface. // Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection. -func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry) (DB, error) { +func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logger) (DB, error) { opts, err := derivePGOptions(c) if err != nil { return nil, fmt.Errorf("could not create postgres service: %s", err) @@ -67,7 +67,7 @@ func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry // this will break the logfmt format we normally log in, // since we can't choose where pg outputs to and it defaults to // stdout. So use this option with care! - if log.Logger.GetLevel() >= logrus.TraceLevel { + if log.GetLevel() >= logrus.TraceLevel { conn.AddQueryHook(pgdebug.DebugHook{ // Print all queries. Verbose: true, @@ -95,7 +95,7 @@ func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry cancel: cancel, } - federatingDB := newFederatingDB(ps, c) + federatingDB := NewFederatingDB(ps, c, log) ps.federationDB = federatingDB // we can confidently return this useable postgres service now @@ -109,8 +109,8 @@ func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry // derivePGOptions takes an application config and returns either a ready-to-use *pg.Options // with sensible defaults, or an error if it's not satisfied by the provided config. func derivePGOptions(c *config.Config) (*pg.Options, error) { - if strings.ToUpper(c.DBConfig.Type) != dbTypePostgres { - return nil, fmt.Errorf("expected db type of %s but got %s", dbTypePostgres, c.DBConfig.Type) + if strings.ToUpper(c.DBConfig.Type) != DBTypePostgres { + return nil, fmt.Errorf("expected db type of %s but got %s", DBTypePostgres, c.DBConfig.Type) } // validate port @@ -341,6 +341,16 @@ func (ps *postgresService) GetAccountByUserID(userID string, account *gtsmodel.A return nil } +func (ps *postgresService) GetLocalAccountByUsername(username string, account *gtsmodel.Account) error { + if err := ps.conn.Model(account).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil +} + func (ps *postgresService) GetFollowRequestsForAccountID(accountID string, followRequests *[]gtsmodel.FollowRequest) error { if err := ps.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil { if err == pg.ErrNoRows { @@ -456,21 +466,23 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr return nil, err } - uris := util.GenerateURIs(username, ps.config.Protocol, ps.config.Host) + newAccountURIs := util.GenerateURIsForAccount(username, ps.config.Protocol, ps.config.Host) a := >smodel.Account{ Username: username, DisplayName: username, Reason: reason, - URL: uris.UserURL, + URL: newAccountURIs.UserURL, PrivateKey: key, PublicKey: &key.PublicKey, + PublicKeyURI: newAccountURIs.PublicKeyURI, ActorType: gtsmodel.ActivityStreamsPerson, - URI: uris.UserURI, - InboxURL: uris.InboxURI, - OutboxURL: uris.OutboxURI, - FollowersURL: uris.FollowersURI, - FeaturedCollectionURL: uris.CollectionURI, + URI: newAccountURIs.UserURI, + InboxURI: newAccountURIs.InboxURI, + OutboxURI: newAccountURIs.OutboxURI, + FollowersURI: newAccountURIs.FollowersURI, + FollowingURI: newAccountURIs.FollowingURI, + FeaturedCollectionURI: newAccountURIs.CollectionURI, } if _, err = ps.conn.Model(a).Insert(); err != nil { return nil, err @@ -566,6 +578,7 @@ func (ps *postgresService) GetAvatarForAccountID(avatar *gtsmodel.MediaAttachmen } func (ps *postgresService) Blocked(account1 string, account2 string) (bool, error) { + // TODO: check domain blocks as well var blocked bool if err := ps.conn.Model(>smodel.Block{}). Where("account_id = ?", account1).Where("target_account_id = ?", account2). |