summaryrefslogtreecommitdiff
path: root/vendor/github.com/uptrace/opentelemetry-go-extra/otelsql/stmt.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/uptrace/opentelemetry-go-extra/otelsql/stmt.go')
-rw-r--r--vendor/github.com/uptrace/opentelemetry-go-extra/otelsql/stmt.go120
1 files changed, 120 insertions, 0 deletions
diff --git a/vendor/github.com/uptrace/opentelemetry-go-extra/otelsql/stmt.go b/vendor/github.com/uptrace/opentelemetry-go-extra/otelsql/stmt.go
new file mode 100644
index 000000000..e87a1e73f
--- /dev/null
+++ b/vendor/github.com/uptrace/opentelemetry-go-extra/otelsql/stmt.go
@@ -0,0 +1,120 @@
+package otelsql
+
+import (
+ "context"
+ "database/sql/driver"
+
+ "go.opentelemetry.io/otel/trace"
+)
+
+type otelStmt struct {
+ driver.Stmt
+
+ query string
+ instrum *dbInstrum
+
+ execCtx stmtExecCtxFunc
+ queryCtx stmtQueryCtxFunc
+}
+
+var _ driver.Stmt = (*otelStmt)(nil)
+
+func newStmt(stmt driver.Stmt, query string, instrum *dbInstrum) *otelStmt {
+ s := &otelStmt{
+ Stmt: stmt,
+ query: query,
+ instrum: instrum,
+ }
+ s.execCtx = s.createExecCtxFunc(stmt)
+ s.queryCtx = s.createQueryCtxFunc(stmt)
+ return s
+}
+
+//------------------------------------------------------------------------------
+
+var _ driver.StmtExecContext = (*otelStmt)(nil)
+
+func (stmt *otelStmt) ExecContext(
+ ctx context.Context, args []driver.NamedValue,
+) (driver.Result, error) {
+ return stmt.execCtx(ctx, args)
+}
+
+type stmtExecCtxFunc func(ctx context.Context, args []driver.NamedValue) (driver.Result, error)
+
+func (s *otelStmt) createExecCtxFunc(stmt driver.Stmt) stmtExecCtxFunc {
+ var fn stmtExecCtxFunc
+
+ if execer, ok := s.Stmt.(driver.StmtExecContext); ok {
+ fn = execer.ExecContext
+ } else {
+ fn = func(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
+ vArgs, err := namedValueToValue(args)
+ if err != nil {
+ return nil, err
+ }
+ return stmt.Exec(vArgs)
+ }
+ }
+
+ return func(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
+ var res driver.Result
+ err := s.instrum.withSpan(ctx, "stmt.Exec", s.query,
+ func(ctx context.Context, span trace.Span) error {
+ var err error
+ res, err = fn(ctx, args)
+ if err != nil {
+ return err
+ }
+
+ if span.IsRecording() {
+ rows, err := res.RowsAffected()
+ if err == nil {
+ span.SetAttributes(dbRowsAffected.Int64(rows))
+ }
+ }
+
+ return nil
+ })
+ return res, err
+ }
+}
+
+//------------------------------------------------------------------------------
+
+var _ driver.StmtQueryContext = (*otelStmt)(nil)
+
+func (stmt *otelStmt) QueryContext(
+ ctx context.Context, args []driver.NamedValue,
+) (driver.Rows, error) {
+ return stmt.queryCtx(ctx, args)
+}
+
+type stmtQueryCtxFunc func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error)
+
+func (s *otelStmt) createQueryCtxFunc(stmt driver.Stmt) stmtQueryCtxFunc {
+ var fn stmtQueryCtxFunc
+
+ if queryer, ok := s.Stmt.(driver.StmtQueryContext); ok {
+ fn = queryer.QueryContext
+ } else {
+ fn = func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
+ vArgs, err := namedValueToValue(args)
+ if err != nil {
+ return nil, err
+ }
+ return s.Query(vArgs)
+ }
+ }
+
+ return func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
+ var rows driver.Rows
+ err := s.instrum.withSpan(ctx, "stmt.Query", s.query,
+ func(ctx context.Context, span trace.Span) error {
+ var err error
+ rows, err = fn(ctx, args)
+ return err
+ })
+ return rows, err
+ }
+}