diff options
Diffstat (limited to 'vendor/github.com/uptrace/opentelemetry-go-extra/otelsql/driver.go')
| -rw-r--r-- | vendor/github.com/uptrace/opentelemetry-go-extra/otelsql/driver.go | 460 |
1 files changed, 460 insertions, 0 deletions
diff --git a/vendor/github.com/uptrace/opentelemetry-go-extra/otelsql/driver.go b/vendor/github.com/uptrace/opentelemetry-go-extra/otelsql/driver.go new file mode 100644 index 000000000..056af3c6c --- /dev/null +++ b/vendor/github.com/uptrace/opentelemetry-go-extra/otelsql/driver.go @@ -0,0 +1,460 @@ +package otelsql + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + + "go.opentelemetry.io/otel/trace" +) + +// Open is a wrapper over sql.Open that instruments the sql.DB to record executed queries +// using OpenTelemetry API. +func Open(driverName, dsn string, opts ...Option) (*sql.DB, error) { + db, err := sql.Open(driverName, dsn) + if err != nil { + return nil, err + } + return patchDB(db, dsn, opts...) +} + +func patchDB(db *sql.DB, dsn string, opts ...Option) (*sql.DB, error) { + dbDriver := db.Driver() + d := newDriver(dbDriver, opts) + + if _, ok := dbDriver.(driver.DriverContext); ok { + connector, err := d.OpenConnector(dsn) + if err != nil { + return nil, err + } + return sqlOpenDB(connector, d.instrum), nil + } + + return sqlOpenDB(&dsnConnector{ + driver: d, + dsn: dsn, + }, d.instrum), nil +} + +// OpenDB is a wrapper over sql.OpenDB that instruments the sql.DB to record executed queries +// using OpenTelemetry API. +func OpenDB(connector driver.Connector, opts ...Option) *sql.DB { + instrum := newDBInstrum(opts) + c := newConnector(connector.Driver(), connector, instrum) + return sqlOpenDB(c, instrum) +} + +func sqlOpenDB(connector driver.Connector, instrum *dbInstrum) *sql.DB { + db := sql.OpenDB(connector) + ReportDBStatsMetrics(db, WithMeterProvider(instrum.meterProvider), WithAttributes(instrum.attrs...)) + return db +} + +type dsnConnector struct { + driver *otelDriver + dsn string +} + +func (c *dsnConnector) Connect(ctx context.Context) (driver.Conn, error) { + var conn driver.Conn + err := c.driver.instrum.withSpan(ctx, "db.Connect", "", + func(ctx context.Context, span trace.Span) error { + var err error + conn, err = c.driver.Open(c.dsn) + return err + }) + return conn, err +} + +func (c *dsnConnector) Driver() driver.Driver { + return c.driver +} + +//------------------------------------------------------------------------------ + +type otelDriver struct { + driver driver.Driver + driverCtx driver.DriverContext + instrum *dbInstrum +} + +var _ driver.DriverContext = (*otelDriver)(nil) + +func newDriver(dr driver.Driver, opts []Option) *otelDriver { + driverCtx, _ := dr.(driver.DriverContext) + d := &otelDriver{ + driver: dr, + driverCtx: driverCtx, + instrum: newDBInstrum(opts), + } + return d +} + +func (d *otelDriver) Open(name string) (driver.Conn, error) { + conn, err := d.driver.Open(name) + if err != nil { + return nil, err + } + return newConn(conn, d.instrum), nil +} + +func (d *otelDriver) OpenConnector(dsn string) (driver.Connector, error) { + connector, err := d.driverCtx.OpenConnector(dsn) + if err != nil { + return nil, err + } + return newConnector(d, connector, d.instrum), nil +} + +//------------------------------------------------------------------------------ + +type connector struct { + driver.Connector + driver driver.Driver + instrum *dbInstrum +} + +var _ driver.Connector = (*connector)(nil) + +func newConnector(d driver.Driver, c driver.Connector, instrum *dbInstrum) *connector { + return &connector{ + driver: d, + Connector: c, + instrum: instrum, + } +} + +func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { + var conn driver.Conn + if err := c.instrum.withSpan(ctx, "db.Connect", "", + func(ctx context.Context, span trace.Span) error { + var err error + conn, err = c.Connector.Connect(ctx) + return err + }); err != nil { + return nil, err + } + return newConn(conn, c.instrum), nil +} + +func (c *connector) Driver() driver.Driver { + return c.driver +} + +//------------------------------------------------------------------------------ + +type otelConn struct { + driver.Conn + + instrum *dbInstrum + + ping pingFunc + exec execFunc + execCtx execCtxFunc + query queryFunc + queryCtx queryCtxFunc + prepareCtx prepareCtxFunc + beginTx beginTxFunc + resetSession resetSessionFunc + checkNamedValue checkNamedValueFunc +} + +var _ driver.Conn = (*otelConn)(nil) + +func newConn(conn driver.Conn, instrum *dbInstrum) *otelConn { + cn := &otelConn{ + Conn: conn, + instrum: instrum, + } + + cn.ping = cn.createPingFunc(conn) + cn.exec = cn.createExecFunc(conn) + cn.execCtx = cn.createExecCtxFunc(conn) + cn.query = cn.createQueryFunc(conn) + cn.queryCtx = cn.createQueryCtxFunc(conn) + cn.prepareCtx = cn.createPrepareCtxFunc(conn) + cn.beginTx = cn.createBeginTxFunc(conn) + cn.resetSession = cn.createResetSessionFunc(conn) + cn.checkNamedValue = cn.createCheckNamedValueFunc(conn) + + return cn +} + +var _ driver.Pinger = (*otelConn)(nil) + +func (c *otelConn) Ping(ctx context.Context) error { + return c.ping(ctx) +} + +type pingFunc func(ctx context.Context) error + +func (c *otelConn) createPingFunc(conn driver.Conn) pingFunc { + if pinger, ok := conn.(driver.Pinger); ok { + return func(ctx context.Context) error { + return c.instrum.withSpan(ctx, "db.Ping", "", + func(ctx context.Context, span trace.Span) error { + return pinger.Ping(ctx) + }) + } + } + return func(ctx context.Context) error { + return driver.ErrSkip + } +} + +//------------------------------------------------------------------------------ + +var _ driver.Execer = (*otelConn)(nil) + +func (c *otelConn) Exec(query string, args []driver.Value) (driver.Result, error) { + return c.exec(query, args) +} + +type execFunc func(query string, args []driver.Value) (driver.Result, error) + +func (c *otelConn) createExecFunc(conn driver.Conn) execFunc { + if execer, ok := conn.(driver.Execer); ok { + return func(query string, args []driver.Value) (driver.Result, error) { + return execer.Exec(query, args) + } + } + return func(query string, args []driver.Value) (driver.Result, error) { + return nil, driver.ErrSkip + } +} + +//------------------------------------------------------------------------------ + +var _ driver.ExecerContext = (*otelConn)(nil) + +func (c *otelConn) ExecContext( + ctx context.Context, query string, args []driver.NamedValue, +) (driver.Result, error) { + return c.execCtx(ctx, query, args) +} + +type execCtxFunc func(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) + +func (c *otelConn) createExecCtxFunc(conn driver.Conn) execCtxFunc { + var fn execCtxFunc + + if execer, ok := conn.(driver.ExecerContext); ok { + fn = execer.ExecContext + } else { + fn = func(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + vArgs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + return c.exec(query, vArgs) + } + } + + return func(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + var res driver.Result + if err := c.instrum.withSpan(ctx, "db.Exec", query, + func(ctx context.Context, span trace.Span) error { + var err error + res, err = fn(ctx, query, args) + if err != nil { + return err + } + + if span.IsRecording() { + rows, err := res.RowsAffected() + if err == nil { + span.SetAttributes(dbRowsAffected.Int64(rows)) + } + } + + return nil + }); err != nil { + return nil, err + } + return res, nil + } +} + +//------------------------------------------------------------------------------ + +var _ driver.Queryer = (*otelConn)(nil) + +func (c *otelConn) Query(query string, args []driver.Value) (driver.Rows, error) { + return c.query(query, args) +} + +type queryFunc func(query string, args []driver.Value) (driver.Rows, error) + +func (c *otelConn) createQueryFunc(conn driver.Conn) queryFunc { + if queryer, ok := c.Conn.(driver.Queryer); ok { + return func(query string, args []driver.Value) (driver.Rows, error) { + return queryer.Query(query, args) + } + } + return func(query string, args []driver.Value) (driver.Rows, error) { + return nil, driver.ErrSkip + } +} + +//------------------------------------------------------------------------------ + +var _ driver.QueryerContext = (*otelConn)(nil) + +func (c *otelConn) QueryContext( + ctx context.Context, query string, args []driver.NamedValue, +) (driver.Rows, error) { + return c.queryCtx(ctx, query, args) +} + +type queryCtxFunc func(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) + +func (c *otelConn) createQueryCtxFunc(conn driver.Conn) queryCtxFunc { + var fn queryCtxFunc + + if queryer, ok := c.Conn.(driver.QueryerContext); ok { + fn = queryer.QueryContext + } else { + fn = func(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + vArgs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + return c.query(query, vArgs) + } + } + + return func(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + var rows driver.Rows + err := c.instrum.withSpan(ctx, "db.Query", query, + func(ctx context.Context, span trace.Span) error { + var err error + rows, err = fn(ctx, query, args) + return err + }) + return rows, err + } +} + +//------------------------------------------------------------------------------ + +var _ driver.ConnPrepareContext = (*otelConn)(nil) + +func (c *otelConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + return c.prepareCtx(ctx, query) +} + +type prepareCtxFunc func(ctx context.Context, query string) (driver.Stmt, error) + +func (c *otelConn) createPrepareCtxFunc(conn driver.Conn) prepareCtxFunc { + var fn prepareCtxFunc + + if preparer, ok := c.Conn.(driver.ConnPrepareContext); ok { + fn = preparer.PrepareContext + } else { + fn = func(ctx context.Context, query string) (driver.Stmt, error) { + return c.Conn.Prepare(query) + } + } + + return func(ctx context.Context, query string) (driver.Stmt, error) { + var stmt driver.Stmt + if err := c.instrum.withSpan(ctx, "db.Prepare", query, + func(ctx context.Context, span trace.Span) error { + var err error + stmt, err = fn(ctx, query) + return err + }); err != nil { + return nil, err + } + return newStmt(stmt, query, c.instrum), nil + } +} + +var _ driver.ConnBeginTx = (*otelConn)(nil) + +func (c *otelConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + return c.beginTx(ctx, opts) +} + +type beginTxFunc func(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) + +func (c *otelConn) createBeginTxFunc(conn driver.Conn) beginTxFunc { + var fn beginTxFunc + + if txor, ok := conn.(driver.ConnBeginTx); ok { + fn = txor.BeginTx + } else { + fn = func(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + return conn.Begin() + } + } + + return func(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + var tx driver.Tx + if err := c.instrum.withSpan(ctx, "db.Begin", "", + func(ctx context.Context, span trace.Span) error { + var err error + tx, err = fn(ctx, opts) + return err + }); err != nil { + return nil, err + } + return newTx(ctx, tx, c.instrum), nil + } +} + +//------------------------------------------------------------------------------ + +var _ driver.SessionResetter = (*otelConn)(nil) + +func (c *otelConn) ResetSession(ctx context.Context) error { + return c.resetSession(ctx) +} + +type resetSessionFunc func(ctx context.Context) error + +func (c *otelConn) createResetSessionFunc(conn driver.Conn) resetSessionFunc { + if resetter, ok := c.Conn.(driver.SessionResetter); ok { + return func(ctx context.Context) error { + return resetter.ResetSession(ctx) + } + } + return func(ctx context.Context) error { + return driver.ErrSkip + } +} + +//------------------------------------------------------------------------------ + +var _ driver.NamedValueChecker = (*otelConn)(nil) + +func (c *otelConn) CheckNamedValue(value *driver.NamedValue) error { + return c.checkNamedValue(value) +} + +type checkNamedValueFunc func(*driver.NamedValue) error + +func (c *otelConn) createCheckNamedValueFunc(conn driver.Conn) checkNamedValueFunc { + if checker, ok := c.Conn.(driver.NamedValueChecker); ok { + return func(value *driver.NamedValue) error { + return checker.CheckNamedValue(value) + } + } + return func(value *driver.NamedValue) error { + return driver.ErrSkip + } +} + +//------------------------------------------------------------------------------ + +func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { + args := make([]driver.Value, len(named)) + for n, param := range named { + if len(param.Name) > 0 { + return nil, errors.New("otelsql: driver does not support named parameters") + } + args[n] = param.Value + } + return args, nil +} |
