diff options
| author | 2024-02-12 11:52:12 +0000 | |
|---|---|---|
| committer | 2024-02-12 11:52:12 +0000 | |
| commit | ede8f43635c37a05b72f250632d0b236a2709ca9 (patch) | |
| tree | 77cc153ca12e85f5a5ae966dea004d3ac375680f /internal/db | |
| parent | [docs] Fix a few things in the bare metal install (#2624) (diff) | |
| download | gotosocial-ede8f43635c37a05b72f250632d0b236a2709ca9.tar.xz | |
[performance] temporarily cache account status counts to reduce no. account counts (#2620)
* temporarily cache account status counts to reduce no. account counts
* whoops, forgot to initAccountCounts()
* use already fetched cache capacity value
* make cache a ptr type
* whoops, use count instead of just select
* fix to correctly use the transaction
* properly wrap that tx :innocent:
* correctly wrap both tx types
* outline retryOnBusy() to allow the fast path to be inlined
* return err on context cancelled
* remove unnecessary storage of context in stmt, fix Exec and Query interface implementations
* shutup linter
Diffstat (limited to 'internal/db')
| -rw-r--r-- | internal/db/bundb/account.go | 58 | ||||
| -rw-r--r-- | internal/db/bundb/drivers.go | 149 | 
2 files changed, 169 insertions, 38 deletions
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 4b4c78726..e0d574f62 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -532,20 +532,56 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g  }  func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, error) { -	return a.db. -		NewSelect(). -		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). -		Where("? = ?", bun.Ident("status.account_id"), accountID). -		Count(ctx) +	counts, err := a.getAccountStatusCounts(ctx, accountID) +	return counts.Statuses, err  }  func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (int, error) { -	return a.db. -		NewSelect(). -		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). -		Where("? = ?", bun.Ident("status.account_id"), accountID). -		Where("? IS NOT NULL", bun.Ident("status.pinned_at")). -		Count(ctx) +	counts, err := a.getAccountStatusCounts(ctx, accountID) +	return counts.Pinned, err +} + +func (a *accountDB) getAccountStatusCounts(ctx context.Context, accountID string) (struct { +	Statuses int +	Pinned   int +}, error) { +	// Check for an already cached copy of account status counts. +	counts, ok := a.state.Caches.GTS.AccountCounts.Get(accountID) +	if ok { +		return counts, nil +	} + +	if err := a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +		var err error + +		// Scan database for account statuses. +		counts.Statuses, err = tx.NewSelect(). +			Table("statuses"). +			Where("? = ?", bun.Ident("account_id"), accountID). +			Count(ctx) +		if err != nil { +			return err +		} + +		// Scan database for pinned statuses. +		counts.Pinned, err = tx.NewSelect(). +			Table("statuses"). +			Where("? = ?", bun.Ident("account_id"), accountID). +			Where("? IS NOT NULL", bun.Ident("pinned_at")). +			Count(ctx) +		if err != nil { +			return err +		} + +		return nil +	}); err != nil { +		return counts, err +	} + +	// Store this account counts result in the cache. +	a.state.Caches.GTS.AccountCounts.Set(accountID, counts) + +	return counts, nil  }  func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, error) { diff --git a/internal/db/bundb/drivers.go b/internal/db/bundb/drivers.go index 14d84e6fa..a70b598d2 100644 --- a/internal/db/bundb/drivers.go +++ b/internal/db/bundb/drivers.go @@ -36,14 +36,14 @@ var (  	sqliteDriver   = getSQLiteDriver()  ) +//go:linkname getSQLiteDriver modernc.org/sqlite.newDriver +func getSQLiteDriver() *sqlite.Driver +  func init() {  	sql.Register("pgx-gts", &PostgreSQLDriver{})  	sql.Register("sqlite-gts", &SQLiteDriver{})  } -//go:linkname getSQLiteDriver modernc.org/sqlite.newDriver -func getSQLiteDriver() *sqlite.Driver -  // PostgreSQLDriver is our own wrapper around the  // pgx/stdlib.Driver{} type in order to wrap further  // SQL driver types with our own err processing. @@ -66,7 +66,10 @@ func (c *PostgreSQLConn) Begin() (driver.Tx, error) {  func (c *PostgreSQLConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {  	tx, err := c.conn.BeginTx(ctx, opts)  	err = processPostgresError(err) -	return tx, err +	if err != nil { +		return nil, err +	} +	return &PostgreSQLTx{tx}, nil  }  func (c *PostgreSQLConn) Prepare(query string) (driver.Stmt, error) { @@ -74,13 +77,16 @@ func (c *PostgreSQLConn) Prepare(query string) (driver.Stmt, error) {  }  func (c *PostgreSQLConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { -	stmt, err := c.conn.PrepareContext(ctx, query) +	st, err := c.conn.PrepareContext(ctx, query)  	err = processPostgresError(err) -	return stmt, err +	if err != nil { +		return nil, err +	} +	return &PostgreSQLStmt{stmt: st.(stmt)}, nil  } -func (c *PostgreSQLConn) Exec(query string, args []driver.NamedValue) (driver.Result, error) { -	return c.ExecContext(context.Background(), query, args) +func (c *PostgreSQLConn) Exec(query string, args []driver.Value) (driver.Result, error) { +	return c.ExecContext(context.Background(), query, toNamedValues(args))  }  func (c *PostgreSQLConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { @@ -89,8 +95,8 @@ func (c *PostgreSQLConn) ExecContext(ctx context.Context, query string, args []d  	return result, err  } -func (c *PostgreSQLConn) Query(query string, args []driver.NamedValue) (driver.Rows, error) { -	return c.QueryContext(context.Background(), query, args) +func (c *PostgreSQLConn) Query(query string, args []driver.Value) (driver.Rows, error) { +	return c.QueryContext(context.Background(), query, toNamedValues(args))  }  func (c *PostgreSQLConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { @@ -115,6 +121,28 @@ func (tx *PostgreSQLTx) Rollback() error {  	return processPostgresError(err)  } +type PostgreSQLStmt struct{ stmt } + +func (stmt *PostgreSQLStmt) Exec(args []driver.Value) (driver.Result, error) { +	return stmt.ExecContext(context.Background(), toNamedValues(args)) +} + +func (stmt *PostgreSQLStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { +	res, err := stmt.stmt.ExecContext(ctx, args) +	err = processSQLiteError(err) +	return res, err +} + +func (stmt *PostgreSQLStmt) Query(args []driver.Value) (driver.Rows, error) { +	return stmt.QueryContext(context.Background(), toNamedValues(args)) +} + +func (stmt *PostgreSQLStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { +	rows, err := stmt.stmt.QueryContext(ctx, args) +	err = processSQLiteError(err) +	return rows, err +} +  // SQLiteDriver is our own wrapper around the  // sqlite.Driver{} type in order to wrap further  // SQL driver types with our own functionality, @@ -141,6 +169,9 @@ func (c *SQLiteConn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx dri  		err = processSQLiteError(err)  		return err  	}) +	if err != nil { +		return nil, err +	}  	return &SQLiteTx{Context: ctx, Tx: tx}, nil  } @@ -148,17 +179,20 @@ func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {  	return c.PrepareContext(context.Background(), query)  } -func (c *SQLiteConn) PrepareContext(ctx context.Context, query string) (stmt driver.Stmt, err error) { +func (c *SQLiteConn) PrepareContext(ctx context.Context, query string) (st driver.Stmt, err error) {  	err = retryOnBusy(ctx, func() error { -		stmt, err = c.conn.PrepareContext(ctx, query) +		st, err = c.conn.PrepareContext(ctx, query)  		err = processSQLiteError(err)  		return err  	}) -	return +	if err != nil { +		return nil, err +	} +	return &SQLiteStmt{st.(stmt)}, nil  } -func (c *SQLiteConn) Exec(query string, args []driver.NamedValue) (driver.Result, error) { -	return c.ExecContext(context.Background(), query, args) +func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) { +	return c.ExecContext(context.Background(), query, toNamedValues(args))  }  func (c *SQLiteConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) { @@ -170,8 +204,8 @@ func (c *SQLiteConn) ExecContext(ctx context.Context, query string, args []drive  	return  } -func (c *SQLiteConn) Query(query string, args []driver.NamedValue) (driver.Rows, error) { -	return c.QueryContext(context.Background(), query, args) +func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) { +	return c.QueryContext(context.Background(), query, toNamedValues(args))  }  func (c *SQLiteConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { @@ -213,29 +247,64 @@ func (tx *SQLiteTx) Rollback() (err error) {  	return  } +type SQLiteStmt struct{ stmt } + +func (stmt *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) { +	return stmt.ExecContext(context.Background(), toNamedValues(args)) +} + +func (stmt *SQLiteStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res driver.Result, err error) { +	err = retryOnBusy(ctx, func() error { +		res, err = stmt.stmt.ExecContext(ctx, args) +		err = processSQLiteError(err) +		return err +	}) +	return +} + +func (stmt *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) { +	return stmt.QueryContext(context.Background(), toNamedValues(args)) +} + +func (stmt *SQLiteStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) { +	err = retryOnBusy(ctx, func() error { +		rows, err = stmt.stmt.QueryContext(ctx, args) +		err = processSQLiteError(err) +		return err +	}) +	return +} +  type conn interface {  	driver.Conn  	driver.ConnPrepareContext +	driver.Execer //nolint:staticcheck  	driver.ExecerContext +	driver.Queryer //nolint:staticcheck  	driver.QueryerContext  	driver.ConnBeginTx  } +type stmt interface { +	driver.Stmt +	driver.StmtExecContext +	driver.StmtQueryContext +} +  // retryOnBusy will retry given function on returned 'errBusy'.  func retryOnBusy(ctx context.Context, fn func() error) error { +	if err := fn(); err != errBusy { +		return err +	} +	return retryOnBusySlow(ctx, fn) +} + +// retryOnBusySlow is the outlined form of retryOnBusy, to allow the fast path (i.e. only +// 1 attempt) to be inlined, leaving the slow retry loop to be a separate function call. +func retryOnBusySlow(ctx context.Context, fn func() error) error {  	var backoff time.Duration  	for i := 0; ; i++ { -		// Perform func. -		err := fn() - -		if err != errBusy { -			// May be nil, or may be -			// some other error, either -			// way return here. -			return err -		} -  		// backoff according to a multiplier of 2ms * 2^2n,  		// up to a maximum possible backoff time of 5 minutes.  		// @@ -257,11 +326,37 @@ func retryOnBusy(ctx context.Context, fn func() error) error {  		select {  		// Context cancelled.  		case <-ctx.Done(): +			return ctx.Err()  		// Backoff for some time.  		case <-time.After(backoff):  		} + +		// Perform func. +		err := fn() + +		if err != errBusy { +			// May be nil, or may be +			// some other error, either +			// way return here. +			return err +		}  	}  	return gtserror.Newf("%w (waited > %s)", db.ErrBusyTimeout, backoff)  } + +// toNamedValues converts older driver.Value types to driver.NamedValue types. +func toNamedValues(args []driver.Value) []driver.NamedValue { +	if args == nil { +		return nil +	} +	args2 := make([]driver.NamedValue, len(args)) +	for i := range args { +		args2[i] = driver.NamedValue{ +			Ordinal: i + 1, +			Value:   args[i], +		} +	} +	return args2 +}  | 
