diff options
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/const.go | 12 | ||||
-rw-r--r-- | internal/db/postgres.go | 91 | ||||
-rw-r--r-- | internal/db/service.go | 10 |
3 files changed, 89 insertions, 24 deletions
diff --git a/internal/db/const.go b/internal/db/const.go index ab0a5c03f..ca4ab3922 100644 --- a/internal/db/const.go +++ b/internal/db/const.go @@ -21,15 +21,23 @@ package db import "regexp" const ( - // general db defaults + /* + general db defaults + */ // default database to use in whatever db implementation we have defaultDatabase string = "gotosocial" + // default address should in most cases be overwritten + defaultAddress string = "localhost" - // implementation-specific defaults + /* + implementation-specific defaults + */ // widely-recognised default postgres port postgresDefaultPort int = 5432 + // default user should in most cases be overwritten + postgresDefaultUser string = "postgres" ) var ipv4Regex = regexp.MustCompile(`^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$`) diff --git a/internal/db/postgres.go b/internal/db/postgres.go index 67d0bb4e9..f4cf47406 100644 --- a/internal/db/postgres.go +++ b/internal/db/postgres.go @@ -23,31 +23,69 @@ import ( "errors" "fmt" "net/url" + "time" "github.com/go-fed/activity/streams/vocab" "github.com/go-pg/pg" + "github.com/sirupsen/logrus" ) type postgresService struct { config *Config conn *pg.DB - ready bool + log *logrus.Entry + cancel context.CancelFunc } // newPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface. // Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection. -func newPostgresService(config *Config) (*postgresService, error) { +func newPostgresService(ctx context.Context, config *Config, log *logrus.Entry) (*postgresService, error) { opts, err := derivePGOptions(config) if err != nil { return nil, fmt.Errorf("could not create postgres service: %s", err) } - conn := pg.Connect(opts) + + readyChan := make(chan interface{}) + opts.OnConnect = func(c *pg.Conn) error { + close(readyChan) + return nil + } + + // create a connection + pgCtx, cancel := context.WithCancel(ctx) + conn := pg.Connect(opts).WithContext(pgCtx) + + // actually *begin* the connection so that we can tell if the db is there + // and listening, and also trigger the opts.OnConnect function passed in above + tx, err := conn.Begin() + if err != nil { + cancel() + return nil, fmt.Errorf("db connection error: %s", err) + } + + // close the transaction we just started so it doesn't hang around + if err := tx.Rollback(); err != nil { + cancel() + return nil, fmt.Errorf("db connection error: %s", err) + } + + // make sure the opts.OnConnect function has been triggered + // and closed the ready channel + select { + case <-readyChan: + log.Infof("postgres connection ready") + case <-time.After(5 * time.Second): + cancel() + return nil, errors.New("db connection timeout") + } + + // we can confidently return this useable postgres service now return &postgresService{ - config, - conn, - false, + config: config, + conn: conn, + log: log, + cancel: cancel, }, nil - } /* @@ -68,22 +106,35 @@ func derivePGOptions(config *Config) (*pg.Options, error) { } // validate address - address := config.Address - if address == "" { - return nil, errors.New("address not provided") + if config.Address == "" { + config.Address = defaultAddress + } + if !hostnameRegex.MatchString(config.Address) && !ipv4Regex.MatchString(config.Address) && config.Address != "localhost" { + return nil, fmt.Errorf("address %s was neither an ipv4 address nor a valid hostname", config.Address) + } + + // validate username + if config.User == "" { + config.User = postgresDefaultUser } - if !hostnameRegex.MatchString(address) && !ipv4Regex.MatchString(address) { - return nil, fmt.Errorf("address %s was neither an ipv4 address nor a valid hostname", address) + + // validate that there's a password + if config.Password == "" { + return nil, errors.New("no password set") + } + + // validate database + if config.Database == "" { + config.Database = defaultDatabase } + // We can rely on the pg library we're using to set + // sensible defaults for everything we don't set here. options := &pg.Options{ Addr: fmt.Sprintf("%s:%d", config.Address, config.Port), User: config.User, Password: config.Password, Database: config.Database, - OnConnect: func(c *pg.Conn) error { - return nil - }, } return options, nil @@ -176,6 +227,12 @@ func (ps *postgresService) Liked(c context.Context, actorIRI *url.URL) (follower EXTRA FUNCTIONS */ -func (ps *postgresService) Ready() bool { - return false +func (ps *postgresService) Stop(ctx context.Context) error { + ps.log.Info("closing db connection") + if err := ps.conn.Close(); err != nil { + // only cancel if there's a problem closing the db + ps.cancel() + return err + } + return nil } diff --git a/internal/db/service.go b/internal/db/service.go index 6163b3c69..9a1d3ce2c 100644 --- a/internal/db/service.go +++ b/internal/db/service.go @@ -19,10 +19,12 @@ package db import ( + "context" "fmt" "strings" "github.com/go-fed/activity/pub" + "github.com/sirupsen/logrus" ) const dbTypePostgres string = "POSTGRES" @@ -39,9 +41,7 @@ type Service interface { /* ANY ADDITIONAL DESIRED FUNCTIONS */ - - // Ready indicates whether the database is ready to handle queries and whatnot. - Ready() bool + Stop(context.Context) error } // Config provides configuration options for the database connection @@ -57,10 +57,10 @@ type Config struct { // NewService returns a new database service that satisfies the Service interface and, by extension, // the go-fed database interface described here: https://github.com/go-fed/activity/blob/master/pub/database.go -func NewService(config *Config) (Service, error) { +func NewService(context context.Context, config *Config, log *logrus.Logger) (Service, error) { switch strings.ToUpper(config.Type) { case dbTypePostgres: - return newPostgresService(config) + return newPostgresService(context, config, log.WithField("service", "db")) default: return nil, fmt.Errorf("database type %s not supported", config.Type) } |