diff options
Diffstat (limited to 'internal/db/bundb/bundb.go')
| -rw-r--r-- | internal/db/bundb/bundb.go | 410 |
1 files changed, 410 insertions, 0 deletions
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go new file mode 100644 index 000000000..49ed09cbd --- /dev/null +++ b/internal/db/bundb/bundb.go @@ -0,0 +1,410 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + "crypto/tls" + "crypto/x509" + "database/sql" + "encoding/pem" + "errors" + "fmt" + "os" + "strings" + "time" + + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/stdlib" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/pgdialect" +) + +const ( + dbTypePostgres = "postgres" + dbTypeSqlite = "sqlite" +) + +var registerTables []interface{} = []interface{}{ + >smodel.StatusToEmoji{}, + >smodel.StatusToTag{}, +} + +// bunDBService satisfies the DB interface +type bunDBService struct { + db.Account + db.Admin + db.Basic + db.Domain + db.Instance + db.Media + db.Mention + db.Notification + db.Relationship + db.Session + db.Status + db.Timeline + config *config.Config + conn *bun.DB + log *logrus.Logger +} + +// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface. +// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection. +func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) { + var sqldb *sql.DB + var conn *bun.DB + + // depending on the database type we're trying to create, we need to use a different driver... + switch strings.ToLower(c.DBConfig.Type) { + case dbTypePostgres: + // POSTGRES + opts, err := deriveBunDBPGOptions(c) + if err != nil { + return nil, fmt.Errorf("could not create bundb postgres options: %s", err) + } + sqldb = stdlib.OpenDB(*opts) + conn = bun.NewDB(sqldb, pgdialect.New()) + case dbTypeSqlite: + // SQLITE + // TODO: https://bun.uptrace.dev/guide/drivers.html#sqlite + default: + return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type)) + } + + // actually *begin* the connection so that we can tell if the db is there and listening + if err := conn.Ping(); err != nil { + return nil, fmt.Errorf("db connection error: %s", err) + } + log.Info("connected to database") + + for _, t := range registerTables { + // https://bun.uptrace.dev/orm/many-to-many-relation/ + conn.RegisterModel(t) + } + + ps := &bunDBService{ + Account: &accountDB{ + config: c, + conn: conn, + log: log, + }, + Admin: &adminDB{ + config: c, + conn: conn, + log: log, + }, + Basic: &basicDB{ + config: c, + conn: conn, + log: log, + }, + Domain: &domainDB{ + config: c, + conn: conn, + log: log, + }, + Instance: &instanceDB{ + config: c, + conn: conn, + log: log, + }, + Media: &mediaDB{ + config: c, + conn: conn, + log: log, + }, + Mention: &mentionDB{ + config: c, + conn: conn, + log: log, + }, + Notification: ¬ificationDB{ + config: c, + conn: conn, + log: log, + }, + Relationship: &relationshipDB{ + config: c, + conn: conn, + log: log, + }, + Session: &sessionDB{ + config: c, + conn: conn, + log: log, + }, + Status: &statusDB{ + config: c, + conn: conn, + log: log, + }, + Timeline: &timelineDB{ + config: c, + conn: conn, + log: log, + }, + config: c, + conn: conn, + log: log, + } + + // we can confidently return this useable service now + return ps, nil +} + +/* + HANDY STUFF +*/ + +// deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options +// with sensible defaults, or an error if it's not satisfied by the provided config. +func deriveBunDBPGOptions(c *config.Config) (*pgx.ConnConfig, error) { + if strings.ToUpper(c.DBConfig.Type) != db.DBTypePostgres { + return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, c.DBConfig.Type) + } + + // validate port + if c.DBConfig.Port == 0 { + return nil, errors.New("no port set") + } + + // validate address + if c.DBConfig.Address == "" { + return nil, errors.New("no address set") + } + + // validate username + if c.DBConfig.User == "" { + return nil, errors.New("no user set") + } + + // validate that there's a password + if c.DBConfig.Password == "" { + return nil, errors.New("no password set") + } + + // validate database + if c.DBConfig.Database == "" { + return nil, errors.New("no database set") + } + + var tlsConfig *tls.Config + switch c.DBConfig.TLSMode { + case config.DBTLSModeDisable, config.DBTLSModeUnset: + break // nothing to do + case config.DBTLSModeEnable: + tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + case config.DBTLSModeRequire: + tlsConfig = &tls.Config{ + InsecureSkipVerify: false, + ServerName: c.DBConfig.Address, + } + } + + if tlsConfig != nil && c.DBConfig.TLSCACert != "" { + // load the system cert pool first -- we'll append the given CA cert to this + certPool, err := x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("error fetching system CA cert pool: %s", err) + } + + // open the file itself and make sure there's something in it + caCertBytes, err := os.ReadFile(c.DBConfig.TLSCACert) + if err != nil { + return nil, fmt.Errorf("error opening CA certificate at %s: %s", c.DBConfig.TLSCACert, err) + } + if len(caCertBytes) == 0 { + return nil, fmt.Errorf("ca cert at %s was empty", c.DBConfig.TLSCACert) + } + + // make sure we have a PEM block + caPem, _ := pem.Decode(caCertBytes) + if caPem == nil { + return nil, fmt.Errorf("could not parse cert at %s into PEM", c.DBConfig.TLSCACert) + } + + // parse the PEM block into the certificate + caCert, err := x509.ParseCertificate(caPem.Bytes) + if err != nil { + return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", c.DBConfig.TLSCACert, err) + } + + // we're happy, add it to the existing pool and then use this pool in our tls config + certPool.AddCert(caCert) + tlsConfig.RootCAs = certPool + } + + cfg, _ := pgx.ParseConfig("") + cfg.Host = c.DBConfig.Address + cfg.Port = uint16(c.DBConfig.Port) + cfg.User = c.DBConfig.User + cfg.Password = c.DBConfig.Password + cfg.TLSConfig = tlsConfig + cfg.Database = c.DBConfig.Database + cfg.PreferSimpleProtocol = true + + return cfg, nil +} + +/* + CONVERSION FUNCTIONS +*/ + +// TODO: move these to the type converter, it's bananas that they're here and not there + +func (ps *bunDBService) MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) { + ogAccount := >smodel.Account{} + if err := ps.conn.NewSelect().Model(ogAccount).Where("id = ?", originAccountID).Scan(ctx); err != nil { + return nil, err + } + + menchies := []*gtsmodel.Mention{} + for _, a := range targetAccounts { + // A mentioned account looks like "@test@example.org" or just "@test" for a local account + // -- we can guarantee this from the regex that targetAccounts should have been derived from. + // But we still need to do a bit of fiddling to get what we need here -- the username and domain (if given). + + // 1. trim off the first @ + t := strings.TrimPrefix(a, "@") + + // 2. split the username and domain + s := strings.Split(t, "@") + + // 3. if it's length 1 it's a local account, length 2 means remote, anything else means something is wrong + var local bool + switch len(s) { + case 1: + local = true + case 2: + local = false + default: + return nil, fmt.Errorf("mentioned account format '%s' was not valid", a) + } + + var username, domain string + username = s[0] + if !local { + domain = s[1] + } + + // 4. check we now have a proper username and domain + if username == "" || (!local && domain == "") { + return nil, fmt.Errorf("username or domain for '%s' was nil", a) + } + + // okay we're good now, we can start pulling accounts out of the database + mentionedAccount := >smodel.Account{} + var err error + + // match username + account, case insensitive + if local { + // local user -- should have a null domain + err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("? IS NULL", bun.Ident("domain")).Scan(ctx) + } else { + // remote user -- should have domain defined + err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("LOWER(?) = LOWER(?)", bun.Ident("domain"), domain).Scan(ctx) + } + + if err != nil { + if err == sql.ErrNoRows { + // no result found for this username/domain so just don't include it as a mencho and carry on about our business + ps.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain) + continue + } + // a serious error has happened so bail + return nil, fmt.Errorf("error getting account with username '%s' and domain '%s': %s", username, domain, err) + } + + // id, createdAt and updatedAt will be populated by the db, so we have everything we need! + menchies = append(menchies, >smodel.Mention{ + StatusID: statusID, + OriginAccountID: ogAccount.ID, + OriginAccountURI: ogAccount.URI, + TargetAccountID: mentionedAccount.ID, + NameString: a, + TargetAccountURI: mentionedAccount.URI, + TargetAccountURL: mentionedAccount.URL, + OriginAccount: mentionedAccount, + }) + } + return menchies, nil +} + +func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) { + newTags := []*gtsmodel.Tag{} + for _, t := range tags { + tag := >smodel.Tag{} + // we can use selectorinsert here to create the new tag if it doesn't exist already + // inserted will be true if this is a new tag we just created + if err := ps.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil { + if err == sql.ErrNoRows { + // tag doesn't exist yet so populate it + newID, err := id.NewRandomULID() + if err != nil { + return nil, err + } + tag.ID = newID + tag.URL = fmt.Sprintf("%s://%s/tags/%s", ps.config.Protocol, ps.config.Host, t) + tag.Name = t + tag.FirstSeenFromAccountID = originAccountID + tag.CreatedAt = time.Now() + tag.UpdatedAt = time.Now() + tag.Useable = true + tag.Listable = true + } else { + return nil, fmt.Errorf("error getting tag with name %s: %s", t, err) + } + } + + // bail already if the tag isn't useable + if !tag.Useable { + continue + } + tag.LastStatusAt = time.Now() + newTags = append(newTags, tag) + } + return newTags, nil +} + +func (ps *bunDBService) EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) { + newEmojis := []*gtsmodel.Emoji{} + for _, e := range emojis { + emoji := >smodel.Emoji{} + err := ps.conn.NewSelect().Model(emoji).Where("shortcode = ?", e).Where("visible_in_picker = true").Where("disabled = false").Scan(ctx) + if err != nil { + if err == sql.ErrNoRows { + // no result found for this username/domain so just don't include it as an emoji and carry on about our business + ps.log.Debugf("no emoji found with shortcode %s, skipping it", e) + continue + } + // a serious error has happened so bail + return nil, fmt.Errorf("error getting emoji with shortcode %s: %s", e, err) + } + newEmojis = append(newEmojis, emoji) + } + return newEmojis, nil +} |
