diff options
Diffstat (limited to 'vendor/github.com/uptrace/opentelemetry-go-extra/otelsql/stmt.go')
-rw-r--r-- | vendor/github.com/uptrace/opentelemetry-go-extra/otelsql/stmt.go | 120 |
1 files changed, 120 insertions, 0 deletions
diff --git a/vendor/github.com/uptrace/opentelemetry-go-extra/otelsql/stmt.go b/vendor/github.com/uptrace/opentelemetry-go-extra/otelsql/stmt.go new file mode 100644 index 000000000..e87a1e73f --- /dev/null +++ b/vendor/github.com/uptrace/opentelemetry-go-extra/otelsql/stmt.go @@ -0,0 +1,120 @@ +package otelsql + +import ( + "context" + "database/sql/driver" + + "go.opentelemetry.io/otel/trace" +) + +type otelStmt struct { + driver.Stmt + + query string + instrum *dbInstrum + + execCtx stmtExecCtxFunc + queryCtx stmtQueryCtxFunc +} + +var _ driver.Stmt = (*otelStmt)(nil) + +func newStmt(stmt driver.Stmt, query string, instrum *dbInstrum) *otelStmt { + s := &otelStmt{ + Stmt: stmt, + query: query, + instrum: instrum, + } + s.execCtx = s.createExecCtxFunc(stmt) + s.queryCtx = s.createQueryCtxFunc(stmt) + return s +} + +//------------------------------------------------------------------------------ + +var _ driver.StmtExecContext = (*otelStmt)(nil) + +func (stmt *otelStmt) ExecContext( + ctx context.Context, args []driver.NamedValue, +) (driver.Result, error) { + return stmt.execCtx(ctx, args) +} + +type stmtExecCtxFunc func(ctx context.Context, args []driver.NamedValue) (driver.Result, error) + +func (s *otelStmt) createExecCtxFunc(stmt driver.Stmt) stmtExecCtxFunc { + var fn stmtExecCtxFunc + + if execer, ok := s.Stmt.(driver.StmtExecContext); ok { + fn = execer.ExecContext + } else { + fn = func(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + vArgs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + return stmt.Exec(vArgs) + } + } + + return func(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + var res driver.Result + err := s.instrum.withSpan(ctx, "stmt.Exec", s.query, + func(ctx context.Context, span trace.Span) error { + var err error + res, err = fn(ctx, args) + if err != nil { + return err + } + + if span.IsRecording() { + rows, err := res.RowsAffected() + if err == nil { + span.SetAttributes(dbRowsAffected.Int64(rows)) + } + } + + return nil + }) + return res, err + } +} + +//------------------------------------------------------------------------------ + +var _ driver.StmtQueryContext = (*otelStmt)(nil) + +func (stmt *otelStmt) QueryContext( + ctx context.Context, args []driver.NamedValue, +) (driver.Rows, error) { + return stmt.queryCtx(ctx, args) +} + +type stmtQueryCtxFunc func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) + +func (s *otelStmt) createQueryCtxFunc(stmt driver.Stmt) stmtQueryCtxFunc { + var fn stmtQueryCtxFunc + + if queryer, ok := s.Stmt.(driver.StmtQueryContext); ok { + fn = queryer.QueryContext + } else { + fn = func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + vArgs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + return s.Query(vArgs) + } + } + + return func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + var rows driver.Rows + err := s.instrum.withSpan(ctx, "stmt.Query", s.query, + func(ctx context.Context, span trace.Span) error { + var err error + rows, err = fn(ctx, args) + return err + }) + return rows, err + } +} |