diff options
author | 2021-08-25 15:34:33 +0200 | |
---|---|---|
committer | 2021-08-25 15:34:33 +0200 | |
commit | 2dc9fc1626507bb54417fc4a1920b847cafb27a2 (patch) | |
tree | 4ddeac479b923db38090aac8bd9209f3646851c1 /vendor/github.com | |
parent | Manually approves followers (#146) (diff) | |
download | gotosocial-2dc9fc1626507bb54417fc4a1920b847cafb27a2.tar.xz |
Pg to bun (#148)
* start moving to bun
* changing more stuff
* more
* and yet more
* tests passing
* seems stable now
* more big changes
* small fix
* little fixes
Diffstat (limited to 'vendor/github.com')
403 files changed, 52621 insertions, 18850 deletions
diff --git a/vendor/github.com/go-pg/pg/extra/pgdebug/go.mod b/vendor/github.com/go-pg/pg/extra/pgdebug/go.mod deleted file mode 100644 index d44ba0123..000000000 --- a/vendor/github.com/go-pg/pg/extra/pgdebug/go.mod +++ /dev/null @@ -1,7 +0,0 @@ -module github.com/go-pg/pg/extra/pgdebug - -go 1.15 - -replace github.com/go-pg/pg/v10 => ../.. - -require github.com/go-pg/pg/v10 v10.6.2 diff --git a/vendor/github.com/go-pg/pg/extra/pgdebug/go.sum b/vendor/github.com/go-pg/pg/extra/pgdebug/go.sum deleted file mode 100644 index 8483a864a..000000000 --- a/vendor/github.com/go-pg/pg/extra/pgdebug/go.sum +++ /dev/null @@ -1,161 +0,0 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= -github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= -github.com/go-pg/zerochecker v0.2.0 h1:pp7f72c3DobMWOb2ErtZsnrPaSvHd2W4o9//8HtF4mU= -github.com/go-pg/zerochecker v0.2.0/go.mod h1:NJZ4wKL0NmTtz0GKCoJ8kym6Xn/EQzXRl2OnAe7MmDo= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= -github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= -github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= -github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= -github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.3 h1:x95R7cp+rSeeqAMI2knLtQ0DKlaBhv2NrtrOvafPHRo= -github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= -github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= -github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= -github.com/nxadm/tail v1.4.4 h1:DQuhQpB1tVlglWS2hLQ5OV6B5r8aGxSrPc5Qo6uTN78= -github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= -github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= -github.com/onsi/ginkgo v1.14.2 h1:8mVmC9kjFFmA8H4pKMUhcblgifdkOIXPvbhN1T36q1M= -github.com/onsi/ginkgo v1.14.2/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= -github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= -github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= -github.com/onsi/gomega v1.10.3 h1:gph6h/qe9GSUw1NhH1gp+qb+h8rXD8Cy60Z32Qw3ELA= -github.com/onsi/gomega v1.10.3/go.mod h1:V9xEwhxec5O8UDM77eCW8vLymOMltsqPVYWrpDsH8xc= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= -github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= -github.com/vmihailenco/bufpool v0.1.11 h1:gOq2WmBrq0i2yW5QJ16ykccQ4wH9UyEsgLm6czKAd94= -github.com/vmihailenco/bufpool v0.1.11/go.mod h1:AFf/MOy3l2CFTKbxwt0mp2MwnqjNEs5H/UxrkA5jxTQ= -github.com/vmihailenco/msgpack/v4 v4.3.11/go.mod h1:gborTTJjAo/GWTqqRjrLCn9pgNN+NXzzngzBKDPIqw4= -github.com/vmihailenco/msgpack/v5 v5.0.0 h1:nCaMMPEyfgwkGc/Y0GreJPhuvzqCqW+Ufq5lY7zLO2c= -github.com/vmihailenco/msgpack/v5 v5.0.0/go.mod h1:HVxBVPUK/+fZMonk4bi1islLa8V3cfnBug0+4dykPzo= -github.com/vmihailenco/tagparser v0.1.1/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI= -github.com/vmihailenco/tagparser v0.1.2 h1:gnjoVuB/kljJ5wICEEOpx98oXMWPLj22G67Vbd1qPqc= -github.com/vmihailenco/tagparser v0.1.2/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI= -go.opentelemetry.io/otel v0.14.0 h1:YFBEfjCk9MTjaytCNSUkp9Q8lF7QJezA06T71FbQxLQ= -go.opentelemetry.io/otel v0.14.0/go.mod h1:vH5xEuwy7Rts0GNtsCW3HYQoZDY+OmBJ6t1bFGGlxgw= -golang.org/x/crypto v0.0.0-20180910181607-0e37d006457b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9 h1:phUcVbl53swtrUN8kQEXFhUxPlIlWyBfKmidCu7P95o= -golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= -golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20201110031124-69a78807bb2b h1:uwuIcX0g4Yl1NC5XAz37xsr2lTtcqevgzYNVt49waME= -golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= -google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= -google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= -gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= -gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -mellium.im/sasl v0.2.1 h1:nspKSRg7/SyO0cRGY71OkfHab8tf9kCts6a6oTDut0w= -mellium.im/sasl v0.2.1/go.mod h1:ROaEDLQNuf9vjKqE1SrAfnsobm2YKXT1gnN1uDp1PjQ= diff --git a/vendor/github.com/go-pg/pg/extra/pgdebug/pgdebug.go b/vendor/github.com/go-pg/pg/extra/pgdebug/pgdebug.go deleted file mode 100644 index bbf6ada19..000000000 --- a/vendor/github.com/go-pg/pg/extra/pgdebug/pgdebug.go +++ /dev/null @@ -1,42 +0,0 @@ -package pgdebug - -import ( - "context" - "fmt" - - "github.com/go-pg/pg/v10" -) - -// DebugHook is a query hook that logs an error with a query if there are any. -// It can be installed with: -// -// db.AddQueryHook(pgext.DebugHook{}) -type DebugHook struct { - // Verbose causes hook to print all queries (even those without an error). - Verbose bool - EmptyLine bool -} - -var _ pg.QueryHook = (*DebugHook)(nil) - -func (h DebugHook) BeforeQuery(ctx context.Context, evt *pg.QueryEvent) (context.Context, error) { - q, err := evt.FormattedQuery() - if err != nil { - return nil, err - } - - if evt.Err != nil { - fmt.Printf("%s executing a query:\n%s\n", evt.Err, q) - } else if h.Verbose { - if h.EmptyLine { - fmt.Println() - } - fmt.Println(string(q)) - } - - return ctx, nil -} - -func (DebugHook) AfterQuery(context.Context, *pg.QueryEvent) error { - return nil -} diff --git a/vendor/github.com/go-pg/pg/v10/.golangci.yml b/vendor/github.com/go-pg/pg/v10/.golangci.yml deleted file mode 100644 index e2b5ce924..000000000 --- a/vendor/github.com/go-pg/pg/v10/.golangci.yml +++ /dev/null @@ -1,18 +0,0 @@ -run: - concurrency: 8 - deadline: 5m - tests: false -linters: - enable-all: true - disable: - - gochecknoglobals - - gocognit - - gomnd - - wsl - - funlen - - godox - - goerr113 - - exhaustive - - nestif - - gofumpt - - goconst diff --git a/vendor/github.com/go-pg/pg/v10/.travis.yml b/vendor/github.com/go-pg/pg/v10/.travis.yml deleted file mode 100644 index 6db22a449..000000000 --- a/vendor/github.com/go-pg/pg/v10/.travis.yml +++ /dev/null @@ -1,21 +0,0 @@ -dist: xenial -language: go - -addons: - postgresql: '9.6' - -go: - - 1.14.x - - 1.15.x - - tip - -matrix: - allow_failures: - - go: tip - -go_import_path: github.com/go-pg/pg - -before_install: - - psql -U postgres -c "CREATE EXTENSION hstore" - - curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- - -b $(go env GOPATH)/bin v1.28.3 diff --git a/vendor/github.com/go-pg/pg/v10/CHANGELOG.md b/vendor/github.com/go-pg/pg/v10/CHANGELOG.md deleted file mode 100644 index 6a8288033..000000000 --- a/vendor/github.com/go-pg/pg/v10/CHANGELOG.md +++ /dev/null @@ -1,204 +0,0 @@ -# Changelog - -> :heart: -> [**Uptrace.dev** - All-in-one tool to optimize performance and monitor errors & logs](https://uptrace.dev) - -**Important**. Please check [Bun](https://bun.uptrace.dev/guide/pg-migration.html) - the next -iteration of go-pg built on top of `sql.DB`. - -## v10.10 - -- Removed extra OpenTelemetry spans from go-pg core. Now go-pg instrumentation only adds a single - span with a SQL query (instead of 4 spans). There are multiple reasons behind this decision: - - - Traces become smaller and less noisy. - - [Bun](https://github.com/uptrace/bun) can't support the same level of instrumentation and it is - nice to keep the projects synced. - - It may be costly to process those 3 extra spans for each query. - - Eventually we hope to replace the information that we no longer collect with OpenTelemetry - Metrics. - -## v10.9 - -- To make updating easier, extra modules now have the same version as go-pg does. That means that - you need to update your imports: - -``` -github.com/go-pg/pg/extra/pgdebug -> github.com/go-pg/pg/extra/pgdebug/v10 -github.com/go-pg/pg/extra/pgotel -> github.com/go-pg/pg/extra/pgotel/v10 -github.com/go-pg/pg/extra/pgsegment -> github.com/go-pg/pg/extra/pgsegment/v10 -``` - -- Exported `pg.Query` which should be used instead of `orm.Query`. -- Added `pg.DBI` which is a DB interface implemented by `pg.DB` and `pg.Tx`. - -## v10 - -### Resources - -- Docs at https://pg.uptrace.dev/ powered by [mkdocs](https://github.com/squidfunk/mkdocs-material). -- [RealWorld example application](https://github.com/uptrace/go-realworld-example-app). -- [Discord](https://discord.gg/rWtp5Aj). - -### Features - -- `Select`, `Insert`, and `Update` support `map[string]interface{}`. `Select` also supports - `[]map[string]interface{}`. - -```go -var mm []map[string]interface{} -err := db.Model((*User)(nil)).Limit(10).Select(&mm) -``` - -- Columns that start with `_` are ignored if there is no destination field. -- Optional [faster json encoding](https://github.com/go-pg/pgext). -- Added [pgext.OpenTelemetryHook](https://github.com/go-pg/pgext) that adds - [OpenTelemetry instrumentation](https://pg.uptrace.dev/tracing/). -- Added [pgext.DebugHook](https://github.com/go-pg/pgext) that logs failed queries. -- Added `db.Ping` to check if database is healthy. - -### Changes - -- ORM relations are reworked and now require `rel` tag option (but existing code will continue - working until v11). Supported options: - - `pg:"rel:has-one"` - has one relation. - - `pg:"rel:belongs-to"` - belongs to relation. - - `pg:"rel:has-many"` - has many relation. - - `pg:"many2many:book_genres"` - many to many relation. -- Changed `pg.QueryHook` to return temp byte slice to reduce memory usage. -- `,msgpack` struct tag marshals data in MessagePack format using - https://github.com/vmihailenco/msgpack -- Empty slices and maps are no longer marshaled as `NULL`. Nil slices and maps are still marshaled - as `NULL`. -- Changed `UpdateNotZero` to include zero fields with `pg:",use_zero"` tag. Consider using - `Model(*map[string]interface{})` for inserts and updates. -- `joinFK` is deprecated in favor of `join_fk`. -- `partitionBy` is deprecated in favor of `partition_by`. -- ORM shortcuts are removed: - - `db.Select(model)` becomes `db.Model(model).WherePK().Select()`. - - `db.Insert(model)` becomes `db.Model(model).Insert()`. - - `db.Update(model)` becomes `db.Model(model).WherePK().Update()`. - - `db.Delete(model)` becomes `db.Model(model).WherePK().Delete()`. -- Deprecated types and funcs are removed. -- `WhereStruct` is removed. - -## v9 - -- `pg:",notnull"` is reworked. Now it means SQL `NOT NULL` constraint and nothing more. -- Added `pg:",use_zero"` to prevent go-pg from converting Go zero values to SQL `NULL`. -- UpdateNotNull is renamed to UpdateNotZero. As previously it omits zero Go values, but it does not - take in account if field is nullable or not. -- ORM supports DistinctOn. -- Hooks accept and return context. -- Client respects Context.Deadline when setting net.Conn deadline. -- Client listens on Context.Done while waiting for a connection from the pool and returns an error - when context is cancelled. -- `Query.Column` does not accept relation name any more. Use `Query.Relation` instead which returns - an error if relation does not exist. -- urlvalues package is removed in favor of https://github.com/go-pg/urlstruct. You can also use - struct based filters via `Query.WhereStruct`. -- `NewModel` and `AddModel` methods of `HooklessModel` interface were renamed to `NextColumnScanner` - and `AddColumnScanner` respectively. -- `types.F` and `pg.F` are deprecated in favor of `pg.Ident`. -- `types.Q` is deprecated in favor of `pg.Safe`. -- `pg.Q` is deprecated in favor of `pg.SafeQuery`. -- `TableName` field is deprecated in favor of `tableName`. -- Always use `pg:"..."` struct field tag instead of `sql:"..."`. -- `pg:",override"` is deprecated in favor of `pg:",inherit"`. - -## v8 - -- Added `QueryContext`, `ExecContext`, and `ModelContext` which accept `context.Context`. Queries - are cancelled when context is cancelled. -- Model hooks are changed to accept `context.Context` as first argument. -- Fixed array and hstore parsers to handle multiple single quotes (#1235). - -## v7 - -- DB.OnQueryProcessed is replaced with DB.AddQueryHook. -- Added WhereStruct. -- orm.Pager is moved to urlvalues.Pager. Pager.FromURLValues returns an error if page or limit - params can't be parsed. - -## v6.16 - -- Read buffer is re-worked. Default read buffer is increased to 65kb. - -## v6.15 - -- Added Options.MinIdleConns. -- Options.MaxAge renamed to Options.MaxConnAge. -- PoolStats.FreeConns is renamed to PoolStats.IdleConns. -- New hook BeforeSelectQuery. -- `,override` is renamed to `,inherit`. -- Dialer.KeepAlive is set to 5 minutes by default. -- Added support "scram-sha-256" authentication. - -## v6.14 - -- Fields ignored with `sql:"-"` tag are no longer considered by ORM relation detector. - -## v6.12 - -- `Insert`, `Update`, and `Delete` can return `pg.ErrNoRows` and `pg.ErrMultiRows` when `Returning` - is used and model expects single row. - -## v6.11 - -- `db.Model(&strct).Update()` and `db.Model(&strct).Delete()` no longer adds WHERE condition based - on primary key when there are no conditions. Instead you should use `db.Update(&strct)` or - `db.Model(&strct).WherePK().Update()`. - -## v6.10 - -- `?Columns` is renamed to `?TableColumns`. `?Columns` is changed to produce column names without - table alias. - -## v6.9 - -- `pg:"fk"` tag now accepts SQL names instead of Go names, e.g. `pg:"fk:ParentId"` becomes - `pg:"fk:parent_id"`. Old code should continue working in most cases, but it is strongly advised to - start using new convention. -- uint and uint64 SQL type is changed from decimal to bigint according to the lesser of two evils - principle. Use `sql:"type:decimal"` to get old behavior. - -## v6.8 - -- `CreateTable` no longer adds ON DELETE hook by default. To get old behavior users should add - `sql:"on_delete:CASCADE"` tag on foreign key field. - -## v6 - -- `types.Result` is renamed to `orm.Result`. -- Added `OnQueryProcessed` event that can be used to log / report queries timing. Query logger is - removed. -- `orm.URLValues` is renamed to `orm.URLFilters`. It no longer adds ORDER clause. -- `orm.Pager` is renamed to `orm.Pagination`. -- Support for net.IP and net.IPNet. -- Support for context.Context. -- Bulk/multi updates. -- Query.WhereGroup for enclosing conditions in parentheses. - -## v5 - -- All fields are nullable by default. `,null` tag is replaced with `,notnull`. -- `Result.Affected` renamed to `Result.RowsAffected`. -- Added `Result.RowsReturned`. -- `Create` renamed to `Insert`, `BeforeCreate` to `BeforeInsert`, `AfterCreate` to `AfterInsert`. -- Indexed placeholders support, e.g. `db.Exec("SELECT ?0 + ?0", 1)`. -- Named placeholders are evaluated when query is executed. -- Added Update and Delete hooks. -- Order reworked to quote column names. OrderExpr added to bypass Order quoting restrictions. -- Group reworked to quote column names. GroupExpr added to bypass Group quoting restrictions. - -## v4 - -- `Options.Host` and `Options.Port` merged into `Options.Addr`. -- Added `Options.MaxRetries`. Now queries are not retried by default. -- `LoadInto` renamed to `Scan`, `ColumnLoader` renamed to `ColumnScanner`, LoadColumn renamed to - ScanColumn, `NewRecord() interface{}` changed to `NewModel() ColumnScanner`, - `AppendQuery(dst []byte) []byte` changed to `AppendValue(dst []byte, quote bool) ([]byte, error)`. -- Structs, maps and slices are marshalled to JSON by default. -- Added support for scanning slices, .e.g. scanning `[]int`. -- Added object relational mapping. diff --git a/vendor/github.com/go-pg/pg/v10/Makefile b/vendor/github.com/go-pg/pg/v10/Makefile deleted file mode 100644 index bacdbadae..000000000 --- a/vendor/github.com/go-pg/pg/v10/Makefile +++ /dev/null @@ -1,27 +0,0 @@ -all: - TZ= go test ./... - TZ= go test ./... -short -race - TZ= go test ./... -run=NONE -bench=. -benchmem - env GOOS=linux GOARCH=386 go test ./... - go vet - golangci-lint run - -.PHONY: cleanTest -cleanTest: - docker rm -fv pg || true - -.PHONY: pre-test -pre-test: cleanTest - docker run -d --name pg -p 5432:5432 -e POSTGRES_HOST_AUTH_METHOD=trust postgres:9.6 - sleep 10 - docker exec pg psql -U postgres -c "CREATE EXTENSION hstore" - -.PHONY: test -test: pre-test - TZ= PGSSLMODE=disable go test ./... -v - -tag: - git tag $(VERSION) - git tag extra/pgdebug/$(VERSION) - git tag extra/pgotel/$(VERSION) - git tag extra/pgsegment/$(VERSION) diff --git a/vendor/github.com/go-pg/pg/v10/README.md b/vendor/github.com/go-pg/pg/v10/README.md deleted file mode 100644 index a624e0b8e..000000000 --- a/vendor/github.com/go-pg/pg/v10/README.md +++ /dev/null @@ -1,240 +0,0 @@ -<p align="center"> - <a href="https://uptrace.dev/?utm_source=gh-pg&utm_campaign=gh-pg-banner1"> - <img src="https://raw.githubusercontent.com/uptrace/roadmap/master/banner1.png"> - </a> -</p> - -# PostgreSQL client and ORM for Golang - -[](https://travis-ci.org/go-pg/pg) -[](https://pkg.go.dev/github.com/go-pg/pg/v10) -[](https://pg.uptrace.dev/) -[](https://discord.gg/rWtp5Aj) - -**Important**. Please check [Bun](https://bun.uptrace.dev/guide/pg-migration.html) - the next -iteration of go-pg built on top of `sql.DB`. - -- Join [Discord](https://discord.gg/rWtp5Aj) to ask questions. -- [Documentation](https://pg.uptrace.dev) -- [Reference](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc) -- [Examples](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#pkg-examples) -- Example projects: - - [treemux](https://github.com/uptrace/go-treemux-realworld-example-app) - - [gin](https://github.com/gogjango/gjango) - - [go-kit](https://github.com/Tsovak/rest-api-demo) - - [aah framework](https://github.com/kieusonlam/golamapi) -- [GraphQL Tutorial on YouTube](https://www.youtube.com/playlist?list=PLzQWIQOqeUSNwXcneWYJHUREAIucJ5UZn). - -## Ecosystem - -- Migrations by [vmihailenco](https://github.com/go-pg/migrations) and - [robinjoseph08](https://github.com/robinjoseph08/go-pg-migrations). -- [Genna - cli tool for generating go-pg models](https://github.com/dizzyfool/genna). -- [bigint](https://github.com/d-fal/bigint) - big.Int type for go-pg. -- [urlstruct](https://github.com/go-pg/urlstruct) to decode `url.Values` into structs. -- [Sharding](https://github.com/go-pg/sharding). -- [go-pg-monitor](https://github.com/hypnoglow/go-pg-monitor) - Prometheus metrics based on go-pg - client stats. - -## Features - -- Basic types: integers, floats, string, bool, time.Time, net.IP, net.IPNet. -- sql.NullBool, sql.NullString, sql.NullInt64, sql.NullFloat64 and - [pg.NullTime](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#NullTime). -- [sql.Scanner](http://golang.org/pkg/database/sql/#Scanner) and - [sql/driver.Valuer](http://golang.org/pkg/database/sql/driver/#Valuer) interfaces. -- Structs, maps and arrays are marshalled as JSON by default. -- PostgreSQL multidimensional Arrays using - [array tag](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-PostgresArrayStructTag) - and [Array wrapper](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Array). -- Hstore using - [hstore tag](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-HstoreStructTag) - and [Hstore wrapper](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Hstore). -- [Composite types](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-CompositeType). -- All struct fields are nullable by default and zero values (empty string, 0, zero time, empty map - or slice, nil ptr) are marshalled as SQL `NULL`. `pg:",notnull"` is used to add SQL `NOT NULL` - constraint and `pg:",use_zero"` to allow Go zero values. -- [Transactions](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Begin). -- [Prepared statements](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Prepare). -- [Notifications](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Listener) using - `LISTEN` and `NOTIFY`. -- [Copying data](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-CopyFrom) using - `COPY FROM` and `COPY TO`. -- [Timeouts](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#Options) and canceling queries using - context.Context. -- Automatic connection pooling with - [circuit breaker](https://en.wikipedia.org/wiki/Circuit_breaker_design_pattern) support. -- Queries retry on network errors. -- Working with models using - [ORM](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model) and - [SQL](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Query). -- Scanning variables using - [ORM](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SelectSomeColumnsIntoVars) - and [SQL](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Scan). -- [SelectOrInsert](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-InsertSelectOrInsert) - using on-conflict. -- [INSERT ... ON CONFLICT DO UPDATE](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-InsertOnConflictDoUpdate) - using ORM. -- Bulk/batch - [inserts](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BulkInsert), - [updates](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BulkUpdate), and - [deletes](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BulkDelete). -- Common table expressions using - [WITH](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SelectWith) and - [WrapWith](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SelectWrapWith). -- [CountEstimate](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-CountEstimate) - using `EXPLAIN` to get - [estimated number of matching rows](https://wiki.postgresql.org/wiki/Count_estimate). -- ORM supports - [has one](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-HasOne), - [belongs to](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BelongsTo), - [has many](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-HasMany), and - [many to many](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-ManyToMany) - with composite/multi-column primary keys. -- [Soft deletes](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SoftDelete). -- [Creating tables from structs](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-CreateTable). -- [ForEach](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-ForEach) that calls - a function for each row returned by the query without loading all rows into the memory. - -## Installation - -go-pg supports 2 last Go versions and requires a Go version with -[modules](https://github.com/golang/go/wiki/Modules) support. So make sure to initialize a Go -module: - -```shell -go mod init github.com/my/repo -``` - -And then install go-pg (note _v10_ in the import; omitting it is a popular mistake): - -```shell -go get github.com/go-pg/pg/v10 -``` - -## Quickstart - -```go -package pg_test - -import ( - "fmt" - - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" -) - -type User struct { - Id int64 - Name string - Emails []string -} - -func (u User) String() string { - return fmt.Sprintf("User<%d %s %v>", u.Id, u.Name, u.Emails) -} - -type Story struct { - Id int64 - Title string - AuthorId int64 - Author *User `pg:"rel:has-one"` -} - -func (s Story) String() string { - return fmt.Sprintf("Story<%d %s %s>", s.Id, s.Title, s.Author) -} - -func ExampleDB_Model() { - db := pg.Connect(&pg.Options{ - User: "postgres", - }) - defer db.Close() - - err := createSchema(db) - if err != nil { - panic(err) - } - - user1 := &User{ - Name: "admin", - Emails: []string{"admin1@admin", "admin2@admin"}, - } - _, err = db.Model(user1).Insert() - if err != nil { - panic(err) - } - - _, err = db.Model(&User{ - Name: "root", - Emails: []string{"root1@root", "root2@root"}, - }).Insert() - if err != nil { - panic(err) - } - - story1 := &Story{ - Title: "Cool story", - AuthorId: user1.Id, - } - _, err = db.Model(story1).Insert() - if err != nil { - panic(err) - } - - // Select user by primary key. - user := &User{Id: user1.Id} - err = db.Model(user).WherePK().Select() - if err != nil { - panic(err) - } - - // Select all users. - var users []User - err = db.Model(&users).Select() - if err != nil { - panic(err) - } - - // Select story and associated author in one query. - story := new(Story) - err = db.Model(story). - Relation("Author"). - Where("story.id = ?", story1.Id). - Select() - if err != nil { - panic(err) - } - - fmt.Println(user) - fmt.Println(users) - fmt.Println(story) - // Output: User<1 admin [admin1@admin admin2@admin]> - // [User<1 admin [admin1@admin admin2@admin]> User<2 root [root1@root root2@root]>] - // Story<1 Cool story User<1 admin [admin1@admin admin2@admin]>> -} - -// createSchema creates database schema for User and Story models. -func createSchema(db *pg.DB) error { - models := []interface{}{ - (*User)(nil), - (*Story)(nil), - } - - for _, model := range models { - err := db.Model(model).CreateTable(&orm.CreateTableOptions{ - Temp: true, - }) - if err != nil { - return err - } - } - return nil -} -``` - -## See also - -- [Fast and flexible HTTP router](https://github.com/vmihailenco/treemux) -- [Golang msgpack](https://github.com/vmihailenco/msgpack) -- [Golang message task queue](https://github.com/vmihailenco/taskq) diff --git a/vendor/github.com/go-pg/pg/v10/base.go b/vendor/github.com/go-pg/pg/v10/base.go deleted file mode 100644 index d13997464..000000000 --- a/vendor/github.com/go-pg/pg/v10/base.go +++ /dev/null @@ -1,618 +0,0 @@ -package pg - -import ( - "context" - "io" - "time" - - "github.com/go-pg/pg/v10/internal" - "github.com/go-pg/pg/v10/internal/pool" - "github.com/go-pg/pg/v10/orm" - "github.com/go-pg/pg/v10/types" -) - -type baseDB struct { - db orm.DB - opt *Options - pool pool.Pooler - - fmter *orm.Formatter - queryHooks []QueryHook -} - -// PoolStats contains the stats of a connection pool. -type PoolStats pool.Stats - -// PoolStats returns connection pool stats. -func (db *baseDB) PoolStats() *PoolStats { - stats := db.pool.Stats() - return (*PoolStats)(stats) -} - -func (db *baseDB) clone() *baseDB { - return &baseDB{ - db: db.db, - opt: db.opt, - pool: db.pool, - - fmter: db.fmter, - queryHooks: copyQueryHooks(db.queryHooks), - } -} - -func (db *baseDB) withPool(p pool.Pooler) *baseDB { - cp := db.clone() - cp.pool = p - return cp -} - -func (db *baseDB) WithTimeout(d time.Duration) *baseDB { - newopt := *db.opt - newopt.ReadTimeout = d - newopt.WriteTimeout = d - - cp := db.clone() - cp.opt = &newopt - return cp -} - -func (db *baseDB) WithParam(param string, value interface{}) *baseDB { - cp := db.clone() - cp.fmter = db.fmter.WithParam(param, value) - return cp -} - -// Param returns value for the param. -func (db *baseDB) Param(param string) interface{} { - return db.fmter.Param(param) -} - -func (db *baseDB) retryBackoff(retry int) time.Duration { - return internal.RetryBackoff(retry, db.opt.MinRetryBackoff, db.opt.MaxRetryBackoff) -} - -func (db *baseDB) getConn(ctx context.Context) (*pool.Conn, error) { - cn, err := db.pool.Get(ctx) - if err != nil { - return nil, err - } - - if cn.Inited { - return cn, nil - } - - if err := db.initConn(ctx, cn); err != nil { - db.pool.Remove(ctx, cn, err) - // It is safe to reset StickyConnPool if conn can't be initialized. - if p, ok := db.pool.(*pool.StickyConnPool); ok { - _ = p.Reset(ctx) - } - if err := internal.Unwrap(err); err != nil { - return nil, err - } - return nil, err - } - - return cn, nil -} - -func (db *baseDB) initConn(ctx context.Context, cn *pool.Conn) error { - if cn.Inited { - return nil - } - cn.Inited = true - - if db.opt.TLSConfig != nil { - err := db.enableSSL(ctx, cn, db.opt.TLSConfig) - if err != nil { - return err - } - } - - err := db.startup(ctx, cn, db.opt.User, db.opt.Password, db.opt.Database, db.opt.ApplicationName) - if err != nil { - return err - } - - if db.opt.OnConnect != nil { - p := pool.NewSingleConnPool(db.pool, cn) - return db.opt.OnConnect(ctx, newConn(ctx, db.withPool(p))) - } - - return nil -} - -func (db *baseDB) releaseConn(ctx context.Context, cn *pool.Conn, err error) { - if isBadConn(err, false) { - db.pool.Remove(ctx, cn, err) - } else { - db.pool.Put(ctx, cn) - } -} - -func (db *baseDB) withConn( - ctx context.Context, fn func(context.Context, *pool.Conn) error, -) error { - cn, err := db.getConn(ctx) - if err != nil { - return err - } - - var fnDone chan struct{} - if ctx != nil && ctx.Done() != nil { - fnDone = make(chan struct{}) - go func() { - select { - case <-fnDone: // fn has finished, skip cancel - case <-ctx.Done(): - err := db.cancelRequest(cn.ProcessID, cn.SecretKey) - if err != nil { - internal.Logger.Printf(ctx, "cancelRequest failed: %s", err) - } - // Signal end of conn use. - fnDone <- struct{}{} - } - }() - } - - defer func() { - if fnDone == nil { - db.releaseConn(ctx, cn, err) - return - } - - select { - case <-fnDone: // wait for cancel to finish request - // Looks like the canceled connection must be always removed from the pool. - db.pool.Remove(ctx, cn, err) - case fnDone <- struct{}{}: // signal fn finish, skip cancel goroutine - db.releaseConn(ctx, cn, err) - } - }() - - err = fn(ctx, cn) - return err -} - -func (db *baseDB) shouldRetry(err error) bool { - switch err { - case io.EOF, io.ErrUnexpectedEOF: - return true - case nil, context.Canceled, context.DeadlineExceeded: - return false - } - - if pgerr, ok := err.(Error); ok { - switch pgerr.Field('C') { - case "40001", // serialization_failure - "53300", // too_many_connections - "55000": // attempted to delete invisible tuple - return true - case "57014": // statement_timeout - return db.opt.RetryStatementTimeout - default: - return false - } - } - - if _, ok := err.(timeoutError); ok { - return true - } - - return false -} - -// Close closes the database client, releasing any open resources. -// -// It is rare to Close a DB, as the DB handle is meant to be -// long-lived and shared between many goroutines. -func (db *baseDB) Close() error { - return db.pool.Close() -} - -// Exec executes a query ignoring returned rows. The params are for any -// placeholders in the query. -func (db *baseDB) Exec(query interface{}, params ...interface{}) (res Result, err error) { - return db.exec(db.db.Context(), query, params...) -} - -func (db *baseDB) ExecContext(c context.Context, query interface{}, params ...interface{}) (Result, error) { - return db.exec(c, query, params...) -} - -func (db *baseDB) exec(ctx context.Context, query interface{}, params ...interface{}) (Result, error) { - wb := pool.GetWriteBuffer() - defer pool.PutWriteBuffer(wb) - - if err := writeQueryMsg(wb, db.fmter, query, params...); err != nil { - return nil, err - } - - ctx, evt, err := db.beforeQuery(ctx, db.db, nil, query, params, wb.Query()) - if err != nil { - return nil, err - } - - var res Result - var lastErr error - for attempt := 0; attempt <= db.opt.MaxRetries; attempt++ { - if attempt > 0 { - if err := internal.Sleep(ctx, db.retryBackoff(attempt-1)); err != nil { - return nil, err - } - } - - lastErr = db.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { - res, err = db.simpleQuery(ctx, cn, wb) - return err - }) - if !db.shouldRetry(lastErr) { - break - } - } - - if err := db.afterQuery(ctx, evt, res, lastErr); err != nil { - return nil, err - } - return res, lastErr -} - -// ExecOne acts like Exec, but query must affect only one row. It -// returns ErrNoRows error when query returns zero rows or -// ErrMultiRows when query returns multiple rows. -func (db *baseDB) ExecOne(query interface{}, params ...interface{}) (Result, error) { - return db.execOne(db.db.Context(), query, params...) -} - -func (db *baseDB) ExecOneContext(ctx context.Context, query interface{}, params ...interface{}) (Result, error) { - return db.execOne(ctx, query, params...) -} - -func (db *baseDB) execOne(c context.Context, query interface{}, params ...interface{}) (Result, error) { - res, err := db.ExecContext(c, query, params...) - if err != nil { - return nil, err - } - - if err := internal.AssertOneRow(res.RowsAffected()); err != nil { - return nil, err - } - return res, nil -} - -// Query executes a query that returns rows, typically a SELECT. -// The params are for any placeholders in the query. -func (db *baseDB) Query(model, query interface{}, params ...interface{}) (res Result, err error) { - return db.query(db.db.Context(), model, query, params...) -} - -func (db *baseDB) QueryContext(c context.Context, model, query interface{}, params ...interface{}) (Result, error) { - return db.query(c, model, query, params...) -} - -func (db *baseDB) query(ctx context.Context, model, query interface{}, params ...interface{}) (Result, error) { - wb := pool.GetWriteBuffer() - defer pool.PutWriteBuffer(wb) - - if err := writeQueryMsg(wb, db.fmter, query, params...); err != nil { - return nil, err - } - - ctx, evt, err := db.beforeQuery(ctx, db.db, model, query, params, wb.Query()) - if err != nil { - return nil, err - } - - var res Result - var lastErr error - for attempt := 0; attempt <= db.opt.MaxRetries; attempt++ { - if attempt > 0 { - if err := internal.Sleep(ctx, db.retryBackoff(attempt-1)); err != nil { - return nil, err - } - } - - lastErr = db.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { - res, err = db.simpleQueryData(ctx, cn, model, wb) - return err - }) - if !db.shouldRetry(lastErr) { - break - } - } - - if err := db.afterQuery(ctx, evt, res, lastErr); err != nil { - return nil, err - } - return res, lastErr -} - -// QueryOne acts like Query, but query must return only one row. It -// returns ErrNoRows error when query returns zero rows or -// ErrMultiRows when query returns multiple rows. -func (db *baseDB) QueryOne(model, query interface{}, params ...interface{}) (Result, error) { - return db.queryOne(db.db.Context(), model, query, params...) -} - -func (db *baseDB) QueryOneContext( - ctx context.Context, model, query interface{}, params ...interface{}, -) (Result, error) { - return db.queryOne(ctx, model, query, params...) -} - -func (db *baseDB) queryOne(ctx context.Context, model, query interface{}, params ...interface{}) (Result, error) { - res, err := db.QueryContext(ctx, model, query, params...) - if err != nil { - return nil, err - } - - if err := internal.AssertOneRow(res.RowsAffected()); err != nil { - return nil, err - } - return res, nil -} - -// CopyFrom copies data from reader to a table. -func (db *baseDB) CopyFrom(r io.Reader, query interface{}, params ...interface{}) (res Result, err error) { - c := db.db.Context() - err = db.withConn(c, func(c context.Context, cn *pool.Conn) error { - res, err = db.copyFrom(c, cn, r, query, params...) - return err - }) - return res, err -} - -// TODO: don't get/put conn in the pool. -func (db *baseDB) copyFrom( - ctx context.Context, cn *pool.Conn, r io.Reader, query interface{}, params ...interface{}, -) (res Result, err error) { - var evt *QueryEvent - - wb := pool.GetWriteBuffer() - defer pool.PutWriteBuffer(wb) - - if err := writeQueryMsg(wb, db.fmter, query, params...); err != nil { - return nil, err - } - - var model interface{} - if len(params) > 0 { - model, _ = params[len(params)-1].(orm.TableModel) - } - - ctx, evt, err = db.beforeQuery(ctx, db.db, model, query, params, wb.Query()) - if err != nil { - return nil, err - } - - // Note that afterQuery uses the err. - defer func() { - if afterQueryErr := db.afterQuery(ctx, evt, res, err); afterQueryErr != nil { - err = afterQueryErr - } - }() - - err = cn.WithWriter(ctx, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { - return writeQueryMsg(wb, db.fmter, query, params...) - }) - if err != nil { - return nil, err - } - - err = cn.WithReader(ctx, db.opt.ReadTimeout, readCopyInResponse) - if err != nil { - return nil, err - } - - for { - err = cn.WithWriter(ctx, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { - return writeCopyData(wb, r) - }) - if err != nil { - if err == io.EOF { - break - } - return nil, err - } - } - - err = cn.WithWriter(ctx, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { - writeCopyDone(wb) - return nil - }) - if err != nil { - return nil, err - } - - err = cn.WithReader(ctx, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { - res, err = readReadyForQuery(rd) - return err - }) - if err != nil { - return nil, err - } - - return res, nil -} - -// CopyTo copies data from a table to writer. -func (db *baseDB) CopyTo(w io.Writer, query interface{}, params ...interface{}) (res Result, err error) { - c := db.db.Context() - err = db.withConn(c, func(c context.Context, cn *pool.Conn) error { - res, err = db.copyTo(c, cn, w, query, params...) - return err - }) - return res, err -} - -func (db *baseDB) copyTo( - ctx context.Context, cn *pool.Conn, w io.Writer, query interface{}, params ...interface{}, -) (res Result, err error) { - var evt *QueryEvent - - wb := pool.GetWriteBuffer() - defer pool.PutWriteBuffer(wb) - - if err := writeQueryMsg(wb, db.fmter, query, params...); err != nil { - return nil, err - } - - var model interface{} - if len(params) > 0 { - model, _ = params[len(params)-1].(orm.TableModel) - } - - ctx, evt, err = db.beforeQuery(ctx, db.db, model, query, params, wb.Query()) - if err != nil { - return nil, err - } - - // Note that afterQuery uses the err. - defer func() { - if afterQueryErr := db.afterQuery(ctx, evt, res, err); afterQueryErr != nil { - err = afterQueryErr - } - }() - - err = cn.WithWriter(ctx, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { - return writeQueryMsg(wb, db.fmter, query, params...) - }) - if err != nil { - return nil, err - } - - err = cn.WithReader(ctx, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { - err := readCopyOutResponse(rd) - if err != nil { - return err - } - - res, err = readCopyData(rd, w) - return err - }) - if err != nil { - return nil, err - } - - return res, nil -} - -// Ping verifies a connection to the database is still alive, -// establishing a connection if necessary. -func (db *baseDB) Ping(ctx context.Context) error { - _, err := db.ExecContext(ctx, "SELECT 1") - return err -} - -// Model returns new query for the model. -func (db *baseDB) Model(model ...interface{}) *Query { - return orm.NewQuery(db.db, model...) -} - -func (db *baseDB) ModelContext(c context.Context, model ...interface{}) *Query { - return orm.NewQueryContext(c, db.db, model...) -} - -func (db *baseDB) Formatter() orm.QueryFormatter { - return db.fmter -} - -func (db *baseDB) cancelRequest(processID, secretKey int32) error { - c := context.TODO() - - cn, err := db.pool.NewConn(c) - if err != nil { - return err - } - defer func() { - _ = db.pool.CloseConn(cn) - }() - - return cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { - writeCancelRequestMsg(wb, processID, secretKey) - return nil - }) -} - -func (db *baseDB) simpleQuery( - c context.Context, cn *pool.Conn, wb *pool.WriteBuffer, -) (*result, error) { - if err := cn.WriteBuffer(c, db.opt.WriteTimeout, wb); err != nil { - return nil, err - } - - var res *result - if err := cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { - var err error - res, err = readSimpleQuery(rd) - return err - }); err != nil { - return nil, err - } - - return res, nil -} - -func (db *baseDB) simpleQueryData( - c context.Context, cn *pool.Conn, model interface{}, wb *pool.WriteBuffer, -) (*result, error) { - if err := cn.WriteBuffer(c, db.opt.WriteTimeout, wb); err != nil { - return nil, err - } - - var res *result - if err := cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { - var err error - res, err = readSimpleQueryData(c, rd, model) - return err - }); err != nil { - return nil, err - } - - return res, nil -} - -// Prepare creates a prepared statement for later queries or -// executions. Multiple queries or executions may be run concurrently -// from the returned statement. -func (db *baseDB) Prepare(q string) (*Stmt, error) { - return prepareStmt(db.withPool(pool.NewStickyConnPool(db.pool)), q) -} - -func (db *baseDB) prepare( - c context.Context, cn *pool.Conn, q string, -) (string, []types.ColumnInfo, error) { - name := cn.NextID() - err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { - writeParseDescribeSyncMsg(wb, name, q) - return nil - }) - if err != nil { - return "", nil, err - } - - var columns []types.ColumnInfo - err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { - columns, err = readParseDescribeSync(rd) - return err - }) - if err != nil { - return "", nil, err - } - - return name, columns, nil -} - -func (db *baseDB) closeStmt(c context.Context, cn *pool.Conn, name string) error { - err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { - writeCloseMsg(wb, name) - writeFlushMsg(wb) - return nil - }) - if err != nil { - return err - } - - err = cn.WithReader(c, db.opt.ReadTimeout, readCloseCompleteMsg) - return err -} diff --git a/vendor/github.com/go-pg/pg/v10/db.go b/vendor/github.com/go-pg/pg/v10/db.go deleted file mode 100644 index 27664783b..000000000 --- a/vendor/github.com/go-pg/pg/v10/db.go +++ /dev/null @@ -1,142 +0,0 @@ -package pg - -import ( - "context" - "fmt" - "time" - - "github.com/go-pg/pg/v10/internal/pool" - "github.com/go-pg/pg/v10/orm" -) - -// Connect connects to a database using provided options. -// -// The returned DB is safe for concurrent use by multiple goroutines -// and maintains its own connection pool. -func Connect(opt *Options) *DB { - opt.init() - return newDB( - context.Background(), - &baseDB{ - opt: opt, - pool: newConnPool(opt), - fmter: orm.NewFormatter(), - }, - ) -} - -func newDB(ctx context.Context, baseDB *baseDB) *DB { - db := &DB{ - baseDB: baseDB.clone(), - ctx: ctx, - } - db.baseDB.db = db - return db -} - -// DB is a database handle representing a pool of zero or more -// underlying connections. It's safe for concurrent use by multiple -// goroutines. -type DB struct { - *baseDB - ctx context.Context -} - -var _ orm.DB = (*DB)(nil) - -func (db *DB) String() string { - return fmt.Sprintf("DB<Addr=%q%s>", db.opt.Addr, db.fmter) -} - -// Options returns read-only Options that were used to connect to the DB. -func (db *DB) Options() *Options { - return db.opt -} - -// Context returns DB context. -func (db *DB) Context() context.Context { - return db.ctx -} - -// WithContext returns a copy of the DB that uses the ctx. -func (db *DB) WithContext(ctx context.Context) *DB { - return newDB(ctx, db.baseDB) -} - -// WithTimeout returns a copy of the DB that uses d as the read/write timeout. -func (db *DB) WithTimeout(d time.Duration) *DB { - return newDB(db.ctx, db.baseDB.WithTimeout(d)) -} - -// WithParam returns a copy of the DB that replaces the param with the value -// in queries. -func (db *DB) WithParam(param string, value interface{}) *DB { - return newDB(db.ctx, db.baseDB.WithParam(param, value)) -} - -// Listen listens for notifications sent with NOTIFY command. -func (db *DB) Listen(ctx context.Context, channels ...string) *Listener { - ln := &Listener{ - db: db, - } - ln.init() - _ = ln.Listen(ctx, channels...) - return ln -} - -// Conn represents a single database connection rather than a pool of database -// connections. Prefer running queries from DB unless there is a specific -// need for a continuous single database connection. -// -// A Conn must call Close to return the connection to the database pool -// and may do so concurrently with a running query. -// -// After a call to Close, all operations on the connection fail. -type Conn struct { - *baseDB - ctx context.Context -} - -var _ orm.DB = (*Conn)(nil) - -// Conn returns a single connection from the connection pool. -// Queries run on the same Conn will be run in the same database session. -// -// Every Conn must be returned to the database pool after use by -// calling Conn.Close. -func (db *DB) Conn() *Conn { - return newConn(db.ctx, db.baseDB.withPool(pool.NewStickyConnPool(db.pool))) -} - -func newConn(ctx context.Context, baseDB *baseDB) *Conn { - conn := &Conn{ - baseDB: baseDB, - ctx: ctx, - } - conn.baseDB.db = conn - return conn -} - -// Context returns DB context. -func (db *Conn) Context() context.Context { - if db.ctx != nil { - return db.ctx - } - return context.Background() -} - -// WithContext returns a copy of the DB that uses the ctx. -func (db *Conn) WithContext(ctx context.Context) *Conn { - return newConn(ctx, db.baseDB) -} - -// WithTimeout returns a copy of the DB that uses d as the read/write timeout. -func (db *Conn) WithTimeout(d time.Duration) *Conn { - return newConn(db.ctx, db.baseDB.WithTimeout(d)) -} - -// WithParam returns a copy of the DB that replaces the param with the value -// in queries. -func (db *Conn) WithParam(param string, value interface{}) *Conn { - return newConn(db.ctx, db.baseDB.WithParam(param, value)) -} diff --git a/vendor/github.com/go-pg/pg/v10/doc.go b/vendor/github.com/go-pg/pg/v10/doc.go deleted file mode 100644 index 9a077a8c1..000000000 --- a/vendor/github.com/go-pg/pg/v10/doc.go +++ /dev/null @@ -1,4 +0,0 @@ -/* -pg provides PostgreSQL client. -*/ -package pg diff --git a/vendor/github.com/go-pg/pg/v10/error.go b/vendor/github.com/go-pg/pg/v10/error.go deleted file mode 100644 index d8113a010..000000000 --- a/vendor/github.com/go-pg/pg/v10/error.go +++ /dev/null @@ -1,69 +0,0 @@ -package pg - -import ( - "net" - - "github.com/go-pg/pg/v10/internal" -) - -// ErrNoRows is returned by QueryOne and ExecOne when query returned zero rows -// but at least one row is expected. -var ErrNoRows = internal.ErrNoRows - -// ErrMultiRows is returned by QueryOne and ExecOne when query returned -// multiple rows but exactly one row is expected. -var ErrMultiRows = internal.ErrMultiRows - -// Error represents an error returned by PostgreSQL server -// using PostgreSQL ErrorResponse protocol. -// -// https://www.postgresql.org/docs/10/static/protocol-message-formats.html -type Error interface { - error - - // Field returns a string value associated with an error field. - // - // https://www.postgresql.org/docs/10/static/protocol-error-fields.html - Field(field byte) string - - // IntegrityViolation reports whether an error is a part of - // Integrity Constraint Violation class of errors. - // - // https://www.postgresql.org/docs/10/static/errcodes-appendix.html - IntegrityViolation() bool -} - -var _ Error = (*internal.PGError)(nil) - -func isBadConn(err error, allowTimeout bool) bool { - if err == nil { - return false - } - if _, ok := err.(internal.Error); ok { - return false - } - if pgErr, ok := err.(Error); ok { - switch pgErr.Field('V') { - case "FATAL", "PANIC": - return true - } - switch pgErr.Field('C') { - case "25P02", // current transaction is aborted - "57014": // canceling statement due to user request - return true - } - return false - } - if allowTimeout { - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - return !netErr.Temporary() - } - } - return true -} - -//------------------------------------------------------------------------------ - -type timeoutError interface { - Timeout() bool -} diff --git a/vendor/github.com/go-pg/pg/v10/go.mod b/vendor/github.com/go-pg/pg/v10/go.mod deleted file mode 100644 index aa867f309..000000000 --- a/vendor/github.com/go-pg/pg/v10/go.mod +++ /dev/null @@ -1,24 +0,0 @@ -module github.com/go-pg/pg/v10 - -go 1.11 - -require ( - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/go-pg/zerochecker v0.2.0 - github.com/golang/protobuf v1.4.3 // indirect - github.com/google/go-cmp v0.5.5 // indirect - github.com/jinzhu/inflection v1.0.0 - github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect - github.com/onsi/ginkgo v1.14.2 - github.com/onsi/gomega v1.10.3 - github.com/stretchr/testify v1.7.0 - github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc - github.com/vmihailenco/bufpool v0.1.11 - github.com/vmihailenco/msgpack/v5 v5.3.1 - github.com/vmihailenco/tagparser v0.1.2 - golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b // indirect - golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 // indirect - google.golang.org/protobuf v1.25.0 // indirect - gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f - mellium.im/sasl v0.2.1 -) diff --git a/vendor/github.com/go-pg/pg/v10/go.sum b/vendor/github.com/go-pg/pg/v10/go.sum deleted file mode 100644 index 7d2d87c0b..000000000 --- a/vendor/github.com/go-pg/pg/v10/go.sum +++ /dev/null @@ -1,154 +0,0 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= -github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= -github.com/go-pg/zerochecker v0.2.0 h1:pp7f72c3DobMWOb2ErtZsnrPaSvHd2W4o9//8HtF4mU= -github.com/go-pg/zerochecker v0.2.0/go.mod h1:NJZ4wKL0NmTtz0GKCoJ8kym6Xn/EQzXRl2OnAe7MmDo= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= -github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= -github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= -github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= -github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= -github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= -github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= -github.com/nxadm/tail v1.4.4 h1:DQuhQpB1tVlglWS2hLQ5OV6B5r8aGxSrPc5Qo6uTN78= -github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= -github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= -github.com/onsi/ginkgo v1.14.2 h1:8mVmC9kjFFmA8H4pKMUhcblgifdkOIXPvbhN1T36q1M= -github.com/onsi/ginkgo v1.14.2/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= -github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= -github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= -github.com/onsi/gomega v1.10.3 h1:gph6h/qe9GSUw1NhH1gp+qb+h8rXD8Cy60Z32Qw3ELA= -github.com/onsi/gomega v1.10.3/go.mod h1:V9xEwhxec5O8UDM77eCW8vLymOMltsqPVYWrpDsH8xc= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= -github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= -github.com/vmihailenco/bufpool v0.1.11 h1:gOq2WmBrq0i2yW5QJ16ykccQ4wH9UyEsgLm6czKAd94= -github.com/vmihailenco/bufpool v0.1.11/go.mod h1:AFf/MOy3l2CFTKbxwt0mp2MwnqjNEs5H/UxrkA5jxTQ= -github.com/vmihailenco/msgpack/v5 v5.3.1 h1:0i85a4dsZh8mC//wmyyTEzidDLPQfQAxZIOLtafGbFY= -github.com/vmihailenco/msgpack/v5 v5.3.1/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= -github.com/vmihailenco/tagparser v0.1.2 h1:gnjoVuB/kljJ5wICEEOpx98oXMWPLj22G67Vbd1qPqc= -github.com/vmihailenco/tagparser v0.1.2/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI= -github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= -github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= -golang.org/x/crypto v0.0.0-20180910181607-0e37d006457b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b h1:7mWr3k41Qtv8XlltBkDkl8LoP3mpSgBW8BUoxtEdbXg= -golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 h1:iGu644GcxtEcrInvDsQRCwJjtCIOlT2V7IRt6ah2Whw= -golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= -google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= -google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= -gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= -gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -mellium.im/sasl v0.2.1 h1:nspKSRg7/SyO0cRGY71OkfHab8tf9kCts6a6oTDut0w= -mellium.im/sasl v0.2.1/go.mod h1:ROaEDLQNuf9vjKqE1SrAfnsobm2YKXT1gnN1uDp1PjQ= diff --git a/vendor/github.com/go-pg/pg/v10/hook.go b/vendor/github.com/go-pg/pg/v10/hook.go deleted file mode 100644 index a95dc20bc..000000000 --- a/vendor/github.com/go-pg/pg/v10/hook.go +++ /dev/null @@ -1,139 +0,0 @@ -package pg - -import ( - "context" - "fmt" - "time" - - "github.com/go-pg/pg/v10/orm" -) - -type ( - BeforeScanHook = orm.BeforeScanHook - AfterScanHook = orm.AfterScanHook - AfterSelectHook = orm.AfterSelectHook - BeforeInsertHook = orm.BeforeInsertHook - AfterInsertHook = orm.AfterInsertHook - BeforeUpdateHook = orm.BeforeUpdateHook - AfterUpdateHook = orm.AfterUpdateHook - BeforeDeleteHook = orm.BeforeDeleteHook - AfterDeleteHook = orm.AfterDeleteHook -) - -//------------------------------------------------------------------------------ - -type dummyFormatter struct{} - -func (dummyFormatter) FormatQuery(b []byte, query string, params ...interface{}) []byte { - return append(b, query...) -} - -// QueryEvent ... -type QueryEvent struct { - StartTime time.Time - DB orm.DB - Model interface{} - Query interface{} - Params []interface{} - fmtedQuery []byte - Result Result - Err error - - Stash map[interface{}]interface{} -} - -// QueryHook ... -type QueryHook interface { - BeforeQuery(context.Context, *QueryEvent) (context.Context, error) - AfterQuery(context.Context, *QueryEvent) error -} - -// UnformattedQuery returns the unformatted query of a query event. -// The query is only valid until the query Result is returned to the user. -func (e *QueryEvent) UnformattedQuery() ([]byte, error) { - return queryString(e.Query) -} - -func queryString(query interface{}) ([]byte, error) { - switch query := query.(type) { - case orm.TemplateAppender: - return query.AppendTemplate(nil) - case string: - return dummyFormatter{}.FormatQuery(nil, query), nil - default: - return nil, fmt.Errorf("pg: can't append %T", query) - } -} - -// FormattedQuery returns the formatted query of a query event. -// The query is only valid until the query Result is returned to the user. -func (e *QueryEvent) FormattedQuery() ([]byte, error) { - return e.fmtedQuery, nil -} - -// AddQueryHook adds a hook into query processing. -func (db *baseDB) AddQueryHook(hook QueryHook) { - db.queryHooks = append(db.queryHooks, hook) -} - -func (db *baseDB) beforeQuery( - ctx context.Context, - ormDB orm.DB, - model, query interface{}, - params []interface{}, - fmtedQuery []byte, -) (context.Context, *QueryEvent, error) { - if len(db.queryHooks) == 0 { - return ctx, nil, nil - } - - event := &QueryEvent{ - StartTime: time.Now(), - DB: ormDB, - Model: model, - Query: query, - Params: params, - fmtedQuery: fmtedQuery, - } - - for i, hook := range db.queryHooks { - var err error - ctx, err = hook.BeforeQuery(ctx, event) - if err != nil { - if err := db.afterQueryFromIndex(ctx, event, i); err != nil { - return ctx, nil, err - } - return ctx, nil, err - } - } - - return ctx, event, nil -} - -func (db *baseDB) afterQuery( - ctx context.Context, - event *QueryEvent, - res Result, - err error, -) error { - if event == nil { - return nil - } - - event.Err = err - event.Result = res - return db.afterQueryFromIndex(ctx, event, len(db.queryHooks)-1) -} - -func (db *baseDB) afterQueryFromIndex(ctx context.Context, event *QueryEvent, hookIndex int) error { - for ; hookIndex >= 0; hookIndex-- { - if err := db.queryHooks[hookIndex].AfterQuery(ctx, event); err != nil { - return err - } - } - return nil -} - -func copyQueryHooks(s []QueryHook) []QueryHook { - return s[:len(s):len(s)] -} diff --git a/vendor/github.com/go-pg/pg/v10/internal/context.go b/vendor/github.com/go-pg/pg/v10/internal/context.go deleted file mode 100644 index 06d20c152..000000000 --- a/vendor/github.com/go-pg/pg/v10/internal/context.go +++ /dev/null @@ -1,26 +0,0 @@ -package internal - -import ( - "context" - "time" -) - -type UndoneContext struct { - context.Context -} - -func UndoContext(ctx context.Context) UndoneContext { - return UndoneContext{Context: ctx} -} - -func (UndoneContext) Deadline() (deadline time.Time, ok bool) { - return time.Time{}, false -} - -func (UndoneContext) Done() <-chan struct{} { - return nil -} - -func (UndoneContext) Err() error { - return nil -} diff --git a/vendor/github.com/go-pg/pg/v10/internal/error.go b/vendor/github.com/go-pg/pg/v10/internal/error.go deleted file mode 100644 index ae6524aeb..000000000 --- a/vendor/github.com/go-pg/pg/v10/internal/error.go +++ /dev/null @@ -1,61 +0,0 @@ -package internal - -import ( - "fmt" -) - -var ( - ErrNoRows = Errorf("pg: no rows in result set") - ErrMultiRows = Errorf("pg: multiple rows in result set") -) - -type Error struct { - s string -} - -func Errorf(s string, args ...interface{}) Error { - return Error{s: fmt.Sprintf(s, args...)} -} - -func (err Error) Error() string { - return err.s -} - -type PGError struct { - m map[byte]string -} - -func NewPGError(m map[byte]string) PGError { - return PGError{ - m: m, - } -} - -func (err PGError) Field(k byte) string { - return err.m[k] -} - -func (err PGError) IntegrityViolation() bool { - switch err.Field('C') { - case "23000", "23001", "23502", "23503", "23505", "23514", "23P01": - return true - default: - return false - } -} - -func (err PGError) Error() string { - return fmt.Sprintf("%s #%s %s", - err.Field('S'), err.Field('C'), err.Field('M')) -} - -func AssertOneRow(l int) error { - switch { - case l == 0: - return ErrNoRows - case l > 1: - return ErrMultiRows - default: - return nil - } -} diff --git a/vendor/github.com/go-pg/pg/v10/internal/internal.go b/vendor/github.com/go-pg/pg/v10/internal/internal.go deleted file mode 100644 index bda5028c6..000000000 --- a/vendor/github.com/go-pg/pg/v10/internal/internal.go +++ /dev/null @@ -1,27 +0,0 @@ -/* -internal is a private internal package. -*/ -package internal - -import ( - "math/rand" - "time" -) - -func RetryBackoff(retry int, minBackoff, maxBackoff time.Duration) time.Duration { - if retry < 0 { - panic("not reached") - } - if minBackoff == 0 { - return 0 - } - - d := minBackoff << uint(retry) - d = minBackoff + time.Duration(rand.Int63n(int64(d))) - - if d > maxBackoff || d < minBackoff { - d = maxBackoff - } - - return d -} diff --git a/vendor/github.com/go-pg/pg/v10/internal/log.go b/vendor/github.com/go-pg/pg/v10/internal/log.go deleted file mode 100644 index 7ea547b10..000000000 --- a/vendor/github.com/go-pg/pg/v10/internal/log.go +++ /dev/null @@ -1,28 +0,0 @@ -package internal - -import ( - "context" - "fmt" - "log" - "os" -) - -var Warn = log.New(os.Stderr, "WARN: pg: ", log.LstdFlags) - -var Deprecated = log.New(os.Stderr, "DEPRECATED: pg: ", log.LstdFlags) - -type Logging interface { - Printf(ctx context.Context, format string, v ...interface{}) -} - -type logger struct { - log *log.Logger -} - -func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) { - _ = l.log.Output(2, fmt.Sprintf(format, v...)) -} - -var Logger Logging = &logger{ - log: log.New(os.Stderr, "pg: ", log.LstdFlags|log.Lshortfile), -} diff --git a/vendor/github.com/go-pg/pg/v10/internal/parser/streaming_parser.go b/vendor/github.com/go-pg/pg/v10/internal/parser/streaming_parser.go deleted file mode 100644 index 723c12b16..000000000 --- a/vendor/github.com/go-pg/pg/v10/internal/parser/streaming_parser.go +++ /dev/null @@ -1,65 +0,0 @@ -package parser - -import ( - "fmt" - - "github.com/go-pg/pg/v10/internal/pool" -) - -type StreamingParser struct { - pool.Reader -} - -func NewStreamingParser(rd pool.Reader) StreamingParser { - return StreamingParser{ - Reader: rd, - } -} - -func (p StreamingParser) SkipByte(skip byte) error { - c, err := p.ReadByte() - if err != nil { - return err - } - if c == skip { - return nil - } - _ = p.UnreadByte() - return fmt.Errorf("got %q, wanted %q", c, skip) -} - -func (p StreamingParser) ReadSubstring(b []byte) ([]byte, error) { - c, err := p.ReadByte() - if err != nil { - return b, err - } - - for { - if c == '"' { - return b, nil - } - - next, err := p.ReadByte() - if err != nil { - return b, err - } - - if c == '\\' { - switch next { - case '\\', '"': - b = append(b, next) - c, err = p.ReadByte() - if err != nil { - return nil, err - } - default: - b = append(b, '\\') - c = next - } - continue - } - - b = append(b, c) - c = next - } -} diff --git a/vendor/github.com/go-pg/pg/v10/internal/pool/conn.go b/vendor/github.com/go-pg/pg/v10/internal/pool/conn.go deleted file mode 100644 index 91045245b..000000000 --- a/vendor/github.com/go-pg/pg/v10/internal/pool/conn.go +++ /dev/null @@ -1,158 +0,0 @@ -package pool - -import ( - "context" - "net" - "strconv" - "sync/atomic" - "time" -) - -var noDeadline = time.Time{} - -type Conn struct { - netConn net.Conn - rd *ReaderContext - - ProcessID int32 - SecretKey int32 - lastID int64 - - createdAt time.Time - usedAt uint32 // atomic - pooled bool - Inited bool -} - -func NewConn(netConn net.Conn) *Conn { - cn := &Conn{ - createdAt: time.Now(), - } - cn.SetNetConn(netConn) - cn.SetUsedAt(time.Now()) - return cn -} - -func (cn *Conn) UsedAt() time.Time { - unix := atomic.LoadUint32(&cn.usedAt) - return time.Unix(int64(unix), 0) -} - -func (cn *Conn) SetUsedAt(tm time.Time) { - atomic.StoreUint32(&cn.usedAt, uint32(tm.Unix())) -} - -func (cn *Conn) RemoteAddr() net.Addr { - return cn.netConn.RemoteAddr() -} - -func (cn *Conn) SetNetConn(netConn net.Conn) { - cn.netConn = netConn - if cn.rd != nil { - cn.rd.Reset(netConn) - } -} - -func (cn *Conn) LockReader() { - if cn.rd != nil { - panic("not reached") - } - cn.rd = NewReaderContext() - cn.rd.Reset(cn.netConn) -} - -func (cn *Conn) NetConn() net.Conn { - return cn.netConn -} - -func (cn *Conn) NextID() string { - cn.lastID++ - return strconv.FormatInt(cn.lastID, 10) -} - -func (cn *Conn) WithReader( - ctx context.Context, timeout time.Duration, fn func(rd *ReaderContext) error, -) error { - if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil { - return err - } - - rd := cn.rd - if rd == nil { - rd = GetReaderContext() - defer PutReaderContext(rd) - - rd.Reset(cn.netConn) - } - - rd.bytesRead = 0 - - if err := fn(rd); err != nil { - return err - } - - return nil -} - -func (cn *Conn) WithWriter( - ctx context.Context, timeout time.Duration, fn func(wb *WriteBuffer) error, -) error { - wb := GetWriteBuffer() - defer PutWriteBuffer(wb) - - if err := fn(wb); err != nil { - return err - } - - return cn.writeBuffer(ctx, timeout, wb) -} - -func (cn *Conn) WriteBuffer(ctx context.Context, timeout time.Duration, wb *WriteBuffer) error { - return cn.writeBuffer(ctx, timeout, wb) -} - -func (cn *Conn) writeBuffer( - ctx context.Context, - timeout time.Duration, - wb *WriteBuffer, -) error { - if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil { - return err - } - if _, err := cn.netConn.Write(wb.Bytes); err != nil { - return err - } - return nil -} - -func (cn *Conn) Close() error { - return cn.netConn.Close() -} - -func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time { - tm := time.Now() - cn.SetUsedAt(tm) - - if timeout > 0 { - tm = tm.Add(timeout) - } - - if ctx != nil { - deadline, ok := ctx.Deadline() - if ok { - if timeout == 0 { - return deadline - } - if deadline.Before(tm) { - return deadline - } - return tm - } - } - - if timeout > 0 { - return tm - } - - return noDeadline -} diff --git a/vendor/github.com/go-pg/pg/v10/internal/pool/pool.go b/vendor/github.com/go-pg/pg/v10/internal/pool/pool.go deleted file mode 100644 index 59f2c72d0..000000000 --- a/vendor/github.com/go-pg/pg/v10/internal/pool/pool.go +++ /dev/null @@ -1,506 +0,0 @@ -package pool - -import ( - "context" - "errors" - "net" - "sync" - "sync/atomic" - "time" - - "github.com/go-pg/pg/v10/internal" -) - -var ( - ErrClosed = errors.New("pg: database is closed") - ErrPoolTimeout = errors.New("pg: connection pool timeout") -) - -var timers = sync.Pool{ - New: func() interface{} { - t := time.NewTimer(time.Hour) - t.Stop() - return t - }, -} - -// Stats contains pool state information and accumulated stats. -type Stats struct { - Hits uint32 // number of times free connection was found in the pool - Misses uint32 // number of times free connection was NOT found in the pool - Timeouts uint32 // number of times a wait timeout occurred - - TotalConns uint32 // number of total connections in the pool - IdleConns uint32 // number of idle connections in the pool - StaleConns uint32 // number of stale connections removed from the pool -} - -type Pooler interface { - NewConn(context.Context) (*Conn, error) - CloseConn(*Conn) error - - Get(context.Context) (*Conn, error) - Put(context.Context, *Conn) - Remove(context.Context, *Conn, error) - - Len() int - IdleLen() int - Stats() *Stats - - Close() error -} - -type Options struct { - Dialer func(context.Context) (net.Conn, error) - OnClose func(*Conn) error - - PoolSize int - MinIdleConns int - MaxConnAge time.Duration - PoolTimeout time.Duration - IdleTimeout time.Duration - IdleCheckFrequency time.Duration -} - -type ConnPool struct { - opt *Options - - dialErrorsNum uint32 // atomic - - _closed uint32 // atomic - - lastDialErrorMu sync.RWMutex - lastDialError error - - queue chan struct{} - - stats Stats - - connsMu sync.Mutex - conns []*Conn - idleConns []*Conn - - poolSize int - idleConnsLen int -} - -var _ Pooler = (*ConnPool)(nil) - -func NewConnPool(opt *Options) *ConnPool { - p := &ConnPool{ - opt: opt, - - queue: make(chan struct{}, opt.PoolSize), - conns: make([]*Conn, 0, opt.PoolSize), - idleConns: make([]*Conn, 0, opt.PoolSize), - } - - p.connsMu.Lock() - p.checkMinIdleConns() - p.connsMu.Unlock() - - if opt.IdleTimeout > 0 && opt.IdleCheckFrequency > 0 { - go p.reaper(opt.IdleCheckFrequency) - } - - return p -} - -func (p *ConnPool) checkMinIdleConns() { - if p.opt.MinIdleConns == 0 { - return - } - for p.poolSize < p.opt.PoolSize && p.idleConnsLen < p.opt.MinIdleConns { - p.poolSize++ - p.idleConnsLen++ - go func() { - err := p.addIdleConn() - if err != nil { - p.connsMu.Lock() - p.poolSize-- - p.idleConnsLen-- - p.connsMu.Unlock() - } - }() - } -} - -func (p *ConnPool) addIdleConn() error { - cn, err := p.dialConn(context.TODO(), true) - if err != nil { - return err - } - - p.connsMu.Lock() - p.conns = append(p.conns, cn) - p.idleConns = append(p.idleConns, cn) - p.connsMu.Unlock() - return nil -} - -func (p *ConnPool) NewConn(c context.Context) (*Conn, error) { - return p.newConn(c, false) -} - -func (p *ConnPool) newConn(c context.Context, pooled bool) (*Conn, error) { - cn, err := p.dialConn(c, pooled) - if err != nil { - return nil, err - } - - p.connsMu.Lock() - - p.conns = append(p.conns, cn) - if pooled { - // If pool is full remove the cn on next Put. - if p.poolSize >= p.opt.PoolSize { - cn.pooled = false - } else { - p.poolSize++ - } - } - - p.connsMu.Unlock() - return cn, nil -} - -func (p *ConnPool) dialConn(c context.Context, pooled bool) (*Conn, error) { - if p.closed() { - return nil, ErrClosed - } - - if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.opt.PoolSize) { - return nil, p.getLastDialError() - } - - netConn, err := p.opt.Dialer(c) - if err != nil { - p.setLastDialError(err) - if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) { - go p.tryDial() - } - return nil, err - } - - cn := NewConn(netConn) - cn.pooled = pooled - return cn, nil -} - -func (p *ConnPool) tryDial() { - for { - if p.closed() { - return - } - - conn, err := p.opt.Dialer(context.TODO()) - if err != nil { - p.setLastDialError(err) - time.Sleep(time.Second) - continue - } - - atomic.StoreUint32(&p.dialErrorsNum, 0) - _ = conn.Close() - return - } -} - -func (p *ConnPool) setLastDialError(err error) { - p.lastDialErrorMu.Lock() - p.lastDialError = err - p.lastDialErrorMu.Unlock() -} - -func (p *ConnPool) getLastDialError() error { - p.lastDialErrorMu.RLock() - err := p.lastDialError - p.lastDialErrorMu.RUnlock() - return err -} - -// Get returns existed connection from the pool or creates a new one. -func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { - if p.closed() { - return nil, ErrClosed - } - - err := p.waitTurn(ctx) - if err != nil { - return nil, err - } - - for { - p.connsMu.Lock() - cn := p.popIdle() - p.connsMu.Unlock() - - if cn == nil { - break - } - - if p.isStaleConn(cn) { - _ = p.CloseConn(cn) - continue - } - - atomic.AddUint32(&p.stats.Hits, 1) - return cn, nil - } - - atomic.AddUint32(&p.stats.Misses, 1) - - newcn, err := p.newConn(ctx, true) - if err != nil { - p.freeTurn() - return nil, err - } - - return newcn, nil -} - -func (p *ConnPool) getTurn() { - p.queue <- struct{}{} -} - -func (p *ConnPool) waitTurn(c context.Context) error { - select { - case <-c.Done(): - return c.Err() - default: - } - - select { - case p.queue <- struct{}{}: - return nil - default: - } - - timer := timers.Get().(*time.Timer) - timer.Reset(p.opt.PoolTimeout) - - select { - case <-c.Done(): - if !timer.Stop() { - <-timer.C - } - timers.Put(timer) - return c.Err() - case p.queue <- struct{}{}: - if !timer.Stop() { - <-timer.C - } - timers.Put(timer) - return nil - case <-timer.C: - timers.Put(timer) - atomic.AddUint32(&p.stats.Timeouts, 1) - return ErrPoolTimeout - } -} - -func (p *ConnPool) freeTurn() { - <-p.queue -} - -func (p *ConnPool) popIdle() *Conn { - if len(p.idleConns) == 0 { - return nil - } - - idx := len(p.idleConns) - 1 - cn := p.idleConns[idx] - p.idleConns = p.idleConns[:idx] - p.idleConnsLen-- - p.checkMinIdleConns() - return cn -} - -func (p *ConnPool) Put(ctx context.Context, cn *Conn) { - if !cn.pooled { - p.Remove(ctx, cn, nil) - return - } - - p.connsMu.Lock() - p.idleConns = append(p.idleConns, cn) - p.idleConnsLen++ - p.connsMu.Unlock() - p.freeTurn() -} - -func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) { - p.removeConnWithLock(cn) - p.freeTurn() - _ = p.closeConn(cn) -} - -func (p *ConnPool) CloseConn(cn *Conn) error { - p.removeConnWithLock(cn) - return p.closeConn(cn) -} - -func (p *ConnPool) removeConnWithLock(cn *Conn) { - p.connsMu.Lock() - p.removeConn(cn) - p.connsMu.Unlock() -} - -func (p *ConnPool) removeConn(cn *Conn) { - for i, c := range p.conns { - if c == cn { - p.conns = append(p.conns[:i], p.conns[i+1:]...) - if cn.pooled { - p.poolSize-- - p.checkMinIdleConns() - } - return - } - } -} - -func (p *ConnPool) closeConn(cn *Conn) error { - if p.opt.OnClose != nil { - _ = p.opt.OnClose(cn) - } - return cn.Close() -} - -// Len returns total number of connections. -func (p *ConnPool) Len() int { - p.connsMu.Lock() - n := len(p.conns) - p.connsMu.Unlock() - return n -} - -// IdleLen returns number of idle connections. -func (p *ConnPool) IdleLen() int { - p.connsMu.Lock() - n := p.idleConnsLen - p.connsMu.Unlock() - return n -} - -func (p *ConnPool) Stats() *Stats { - idleLen := p.IdleLen() - return &Stats{ - Hits: atomic.LoadUint32(&p.stats.Hits), - Misses: atomic.LoadUint32(&p.stats.Misses), - Timeouts: atomic.LoadUint32(&p.stats.Timeouts), - - TotalConns: uint32(p.Len()), - IdleConns: uint32(idleLen), - StaleConns: atomic.LoadUint32(&p.stats.StaleConns), - } -} - -func (p *ConnPool) closed() bool { - return atomic.LoadUint32(&p._closed) == 1 -} - -func (p *ConnPool) Filter(fn func(*Conn) bool) error { - var firstErr error - p.connsMu.Lock() - for _, cn := range p.conns { - if fn(cn) { - if err := p.closeConn(cn); err != nil && firstErr == nil { - firstErr = err - } - } - } - p.connsMu.Unlock() - return firstErr -} - -func (p *ConnPool) Close() error { - if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) { - return ErrClosed - } - - var firstErr error - p.connsMu.Lock() - for _, cn := range p.conns { - if err := p.closeConn(cn); err != nil && firstErr == nil { - firstErr = err - } - } - p.conns = nil - p.poolSize = 0 - p.idleConns = nil - p.idleConnsLen = 0 - p.connsMu.Unlock() - - return firstErr -} - -func (p *ConnPool) reaper(frequency time.Duration) { - ticker := time.NewTicker(frequency) - defer ticker.Stop() - - for range ticker.C { - if p.closed() { - break - } - n, err := p.ReapStaleConns() - if err != nil { - internal.Logger.Printf(context.TODO(), "ReapStaleConns failed: %s", err) - continue - } - atomic.AddUint32(&p.stats.StaleConns, uint32(n)) - } -} - -func (p *ConnPool) ReapStaleConns() (int, error) { - var n int - for { - p.getTurn() - - p.connsMu.Lock() - cn := p.reapStaleConn() - p.connsMu.Unlock() - - p.freeTurn() - - if cn != nil { - _ = p.closeConn(cn) - n++ - } else { - break - } - } - return n, nil -} - -func (p *ConnPool) reapStaleConn() *Conn { - if len(p.idleConns) == 0 { - return nil - } - - cn := p.idleConns[0] - if !p.isStaleConn(cn) { - return nil - } - - p.idleConns = append(p.idleConns[:0], p.idleConns[1:]...) - p.idleConnsLen-- - p.removeConn(cn) - - return cn -} - -func (p *ConnPool) isStaleConn(cn *Conn) bool { - if p.opt.IdleTimeout == 0 && p.opt.MaxConnAge == 0 { - return false - } - - now := time.Now() - if p.opt.IdleTimeout > 0 && now.Sub(cn.UsedAt()) >= p.opt.IdleTimeout { - return true - } - if p.opt.MaxConnAge > 0 && now.Sub(cn.createdAt) >= p.opt.MaxConnAge { - return true - } - - return false -} diff --git a/vendor/github.com/go-pg/pg/v10/internal/pool/pool_single.go b/vendor/github.com/go-pg/pg/v10/internal/pool/pool_single.go deleted file mode 100644 index 5a3fde191..000000000 --- a/vendor/github.com/go-pg/pg/v10/internal/pool/pool_single.go +++ /dev/null @@ -1,58 +0,0 @@ -package pool - -import "context" - -type SingleConnPool struct { - pool Pooler - cn *Conn - stickyErr error -} - -var _ Pooler = (*SingleConnPool)(nil) - -func NewSingleConnPool(pool Pooler, cn *Conn) *SingleConnPool { - return &SingleConnPool{ - pool: pool, - cn: cn, - } -} - -func (p *SingleConnPool) NewConn(ctx context.Context) (*Conn, error) { - return p.pool.NewConn(ctx) -} - -func (p *SingleConnPool) CloseConn(cn *Conn) error { - return p.pool.CloseConn(cn) -} - -func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) { - if p.stickyErr != nil { - return nil, p.stickyErr - } - return p.cn, nil -} - -func (p *SingleConnPool) Put(ctx context.Context, cn *Conn) {} - -func (p *SingleConnPool) Remove(ctx context.Context, cn *Conn, reason error) { - p.cn = nil - p.stickyErr = reason -} - -func (p *SingleConnPool) Close() error { - p.cn = nil - p.stickyErr = ErrClosed - return nil -} - -func (p *SingleConnPool) Len() int { - return 0 -} - -func (p *SingleConnPool) IdleLen() int { - return 0 -} - -func (p *SingleConnPool) Stats() *Stats { - return &Stats{} -} diff --git a/vendor/github.com/go-pg/pg/v10/internal/pool/pool_sticky.go b/vendor/github.com/go-pg/pg/v10/internal/pool/pool_sticky.go deleted file mode 100644 index 0415b5e87..000000000 --- a/vendor/github.com/go-pg/pg/v10/internal/pool/pool_sticky.go +++ /dev/null @@ -1,202 +0,0 @@ -package pool - -import ( - "context" - "errors" - "fmt" - "sync/atomic" -) - -const ( - stateDefault = 0 - stateInited = 1 - stateClosed = 2 -) - -type BadConnError struct { - wrapped error -} - -var _ error = (*BadConnError)(nil) - -func (e BadConnError) Error() string { - s := "pg: Conn is in a bad state" - if e.wrapped != nil { - s += ": " + e.wrapped.Error() - } - return s -} - -func (e BadConnError) Unwrap() error { - return e.wrapped -} - -//------------------------------------------------------------------------------ - -type StickyConnPool struct { - pool Pooler - shared int32 // atomic - - state uint32 // atomic - ch chan *Conn - - _badConnError atomic.Value -} - -var _ Pooler = (*StickyConnPool)(nil) - -func NewStickyConnPool(pool Pooler) *StickyConnPool { - p, ok := pool.(*StickyConnPool) - if !ok { - p = &StickyConnPool{ - pool: pool, - ch: make(chan *Conn, 1), - } - } - atomic.AddInt32(&p.shared, 1) - return p -} - -func (p *StickyConnPool) NewConn(ctx context.Context) (*Conn, error) { - return p.pool.NewConn(ctx) -} - -func (p *StickyConnPool) CloseConn(cn *Conn) error { - return p.pool.CloseConn(cn) -} - -func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) { - // In worst case this races with Close which is not a very common operation. - for i := 0; i < 1000; i++ { - switch atomic.LoadUint32(&p.state) { - case stateDefault: - cn, err := p.pool.Get(ctx) - if err != nil { - return nil, err - } - if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) { - return cn, nil - } - p.pool.Remove(ctx, cn, ErrClosed) - case stateInited: - if err := p.badConnError(); err != nil { - return nil, err - } - cn, ok := <-p.ch - if !ok { - return nil, ErrClosed - } - return cn, nil - case stateClosed: - return nil, ErrClosed - default: - panic("not reached") - } - } - return nil, fmt.Errorf("pg: StickyConnPool.Get: infinite loop") -} - -func (p *StickyConnPool) Put(ctx context.Context, cn *Conn) { - defer func() { - if recover() != nil { - p.freeConn(ctx, cn) - } - }() - p.ch <- cn -} - -func (p *StickyConnPool) freeConn(ctx context.Context, cn *Conn) { - if err := p.badConnError(); err != nil { - p.pool.Remove(ctx, cn, err) - } else { - p.pool.Put(ctx, cn) - } -} - -func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) { - defer func() { - if recover() != nil { - p.pool.Remove(ctx, cn, ErrClosed) - } - }() - p._badConnError.Store(BadConnError{wrapped: reason}) - p.ch <- cn -} - -func (p *StickyConnPool) Close() error { - if shared := atomic.AddInt32(&p.shared, -1); shared > 0 { - return nil - } - - for i := 0; i < 1000; i++ { - state := atomic.LoadUint32(&p.state) - if state == stateClosed { - return ErrClosed - } - if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) { - close(p.ch) - cn, ok := <-p.ch - if ok { - p.freeConn(context.TODO(), cn) - } - return nil - } - } - - return errors.New("pg: StickyConnPool.Close: infinite loop") -} - -func (p *StickyConnPool) Reset(ctx context.Context) error { - if p.badConnError() == nil { - return nil - } - - select { - case cn, ok := <-p.ch: - if !ok { - return ErrClosed - } - p.pool.Remove(ctx, cn, ErrClosed) - p._badConnError.Store(BadConnError{wrapped: nil}) - default: - return errors.New("pg: StickyConnPool does not have a Conn") - } - - if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) { - state := atomic.LoadUint32(&p.state) - return fmt.Errorf("pg: invalid StickyConnPool state: %d", state) - } - - return nil -} - -func (p *StickyConnPool) badConnError() error { - if v := p._badConnError.Load(); v != nil { - err := v.(BadConnError) - if err.wrapped != nil { - return err - } - } - return nil -} - -func (p *StickyConnPool) Len() int { - switch atomic.LoadUint32(&p.state) { - case stateDefault: - return 0 - case stateInited: - return 1 - case stateClosed: - return 0 - default: - panic("not reached") - } -} - -func (p *StickyConnPool) IdleLen() int { - return len(p.ch) -} - -func (p *StickyConnPool) Stats() *Stats { - return &Stats{} -} diff --git a/vendor/github.com/go-pg/pg/v10/internal/pool/reader.go b/vendor/github.com/go-pg/pg/v10/internal/pool/reader.go deleted file mode 100644 index b5d00807d..000000000 --- a/vendor/github.com/go-pg/pg/v10/internal/pool/reader.go +++ /dev/null @@ -1,80 +0,0 @@ -package pool - -import ( - "sync" -) - -type Reader interface { - Buffered() int - - Bytes() []byte - Read([]byte) (int, error) - ReadByte() (byte, error) - UnreadByte() error - ReadSlice(byte) ([]byte, error) - Discard(int) (int, error) - - // ReadBytes(fn func(byte) bool) ([]byte, error) - // ReadN(int) ([]byte, error) - ReadFull() ([]byte, error) - ReadFullTemp() ([]byte, error) -} - -type ColumnInfo struct { - Index int16 - DataType int32 - Name string -} - -type ColumnAlloc struct { - columns []ColumnInfo -} - -func NewColumnAlloc() *ColumnAlloc { - return new(ColumnAlloc) -} - -func (c *ColumnAlloc) Reset() { - c.columns = c.columns[:0] -} - -func (c *ColumnAlloc) New(index int16, name []byte) *ColumnInfo { - c.columns = append(c.columns, ColumnInfo{ - Index: index, - Name: string(name), - }) - return &c.columns[len(c.columns)-1] -} - -func (c *ColumnAlloc) Columns() []ColumnInfo { - return c.columns -} - -type ReaderContext struct { - *BufReader - ColumnAlloc *ColumnAlloc -} - -func NewReaderContext() *ReaderContext { - const bufSize = 1 << 20 // 1mb - return &ReaderContext{ - BufReader: NewBufReader(bufSize), - ColumnAlloc: NewColumnAlloc(), - } -} - -var readerPool = sync.Pool{ - New: func() interface{} { - return NewReaderContext() - }, -} - -func GetReaderContext() *ReaderContext { - rd := readerPool.Get().(*ReaderContext) - return rd -} - -func PutReaderContext(rd *ReaderContext) { - rd.ColumnAlloc.Reset() - readerPool.Put(rd) -} diff --git a/vendor/github.com/go-pg/pg/v10/internal/pool/reader_buf.go b/vendor/github.com/go-pg/pg/v10/internal/pool/reader_buf.go deleted file mode 100644 index 3172e8b05..000000000 --- a/vendor/github.com/go-pg/pg/v10/internal/pool/reader_buf.go +++ /dev/null @@ -1,431 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package pool - -import ( - "bufio" - "bytes" - "io" -) - -type BufReader struct { - rd io.Reader // reader provided by the client - - buf []byte - r, w int // buf read and write positions - lastByte int - bytesRead int64 - err error - - available int // bytes available for reading - brd BytesReader // reusable bytes reader -} - -func NewBufReader(bufSize int) *BufReader { - return &BufReader{ - buf: make([]byte, bufSize), - available: -1, - } -} - -func (b *BufReader) BytesReader(n int) *BytesReader { - if n == -1 { - n = 0 - } - buf := b.buf[b.r : b.r+n] - b.r += n - b.brd.Reset(buf) - return &b.brd -} - -func (b *BufReader) SetAvailable(n int) { - b.available = n -} - -func (b *BufReader) Available() int { - return b.available -} - -func (b *BufReader) changeAvailable(n int) { - if b.available != -1 { - b.available += n - } -} - -func (b *BufReader) Reset(rd io.Reader) { - b.rd = rd - b.r, b.w = 0, 0 - b.err = nil -} - -// Buffered returns the number of bytes that can be read from the current buffer. -func (b *BufReader) Buffered() int { - buffered := b.w - b.r - if b.available == -1 || buffered <= b.available { - return buffered - } - return b.available -} - -func (b *BufReader) Bytes() []byte { - if b.available == -1 { - return b.buf[b.r:b.w] - } - w := b.r + b.available - if w > b.w { - w = b.w - } - return b.buf[b.r:w] -} - -func (b *BufReader) flush() []byte { - if b.available == -1 { - buf := b.buf[b.r:b.w] - b.r = b.w - return buf - } - - w := b.r + b.available - if w > b.w { - w = b.w - } - buf := b.buf[b.r:w] - b.r = w - b.changeAvailable(-len(buf)) - return buf -} - -// fill reads a new chunk into the buffer. -func (b *BufReader) fill() { - // Slide existing data to beginning. - if b.r > 0 { - copy(b.buf, b.buf[b.r:b.w]) - b.w -= b.r - b.r = 0 - } - - if b.w >= len(b.buf) { - panic("bufio: tried to fill full buffer") - } - if b.available == 0 { - b.err = io.EOF - return - } - - // Read new data: try a limited number of times. - const maxConsecutiveEmptyReads = 100 - for i := maxConsecutiveEmptyReads; i > 0; i-- { - n, err := b.read(b.buf[b.w:]) - b.w += n - if err != nil { - b.err = err - return - } - if n > 0 { - return - } - } - b.err = io.ErrNoProgress -} - -func (b *BufReader) readErr() error { - err := b.err - b.err = nil - return err -} - -func (b *BufReader) Read(p []byte) (n int, err error) { - if len(p) == 0 { - return 0, b.readErr() - } - - if b.available != -1 { - if b.available == 0 { - return 0, io.EOF - } - if len(p) > b.available { - p = p[:b.available] - } - } - - if b.r == b.w { - if b.err != nil { - return 0, b.readErr() - } - - if len(p) >= len(b.buf) { - // Large read, empty buffer. - // Read directly into p to avoid copy. - n, err = b.read(p) - if n > 0 { - b.changeAvailable(-n) - b.lastByte = int(p[n-1]) - } - return n, err - } - - // One read. - // Do not use b.fill, which will loop. - b.r = 0 - b.w = 0 - n, b.err = b.read(b.buf) - if n == 0 { - return 0, b.readErr() - } - b.w += n - } - - // copy as much as we can - n = copy(p, b.Bytes()) - b.r += n - b.changeAvailable(-n) - b.lastByte = int(b.buf[b.r-1]) - return n, nil -} - -// ReadSlice reads until the first occurrence of delim in the input, -// returning a slice pointing at the bytes in the buffer. -// The bytes stop being valid at the next read. -// If ReadSlice encounters an error before finding a delimiter, -// it returns all the data in the buffer and the error itself (often io.EOF). -// ReadSlice fails with error ErrBufferFull if the buffer fills without a delim. -// Because the data returned from ReadSlice will be overwritten -// by the next I/O operation, most clients should use -// ReadBytes or ReadString instead. -// ReadSlice returns err != nil if and only if line does not end in delim. -func (b *BufReader) ReadSlice(delim byte) (line []byte, err error) { - for { - // Search buffer. - if i := bytes.IndexByte(b.Bytes(), delim); i >= 0 { - i++ - line = b.buf[b.r : b.r+i] - b.r += i - b.changeAvailable(-i) - break - } - - // Pending error? - if b.err != nil { - line = b.flush() - err = b.readErr() - break - } - - buffered := b.Buffered() - - // Out of available. - if b.available != -1 && buffered >= b.available { - line = b.flush() - err = io.EOF - break - } - - // Buffer full? - if buffered >= len(b.buf) { - line = b.flush() - err = bufio.ErrBufferFull - break - } - - b.fill() // buffer is not full - } - - // Handle last byte, if any. - if i := len(line) - 1; i >= 0 { - b.lastByte = int(line[i]) - } - - return line, err -} - -func (b *BufReader) ReadBytes(fn func(byte) bool) (line []byte, err error) { - for { - for i, c := range b.Bytes() { - if !fn(c) { - i-- - line = b.buf[b.r : b.r+i] //nolint - b.r += i - b.changeAvailable(-i) - break - } - } - - // Pending error? - if b.err != nil { - line = b.flush() - err = b.readErr() - break - } - - buffered := b.Buffered() - - // Out of available. - if b.available != -1 && buffered >= b.available { - line = b.flush() - err = io.EOF - break - } - - // Buffer full? - if buffered >= len(b.buf) { - line = b.flush() - err = bufio.ErrBufferFull - break - } - - b.fill() // buffer is not full - } - - // Handle last byte, if any. - if i := len(line) - 1; i >= 0 { - b.lastByte = int(line[i]) - } - - return line, err -} - -func (b *BufReader) ReadByte() (byte, error) { - if b.available == 0 { - return 0, io.EOF - } - for b.r == b.w { - if b.err != nil { - return 0, b.readErr() - } - b.fill() // buffer is empty - } - c := b.buf[b.r] - b.r++ - b.lastByte = int(c) - b.changeAvailable(-1) - return c, nil -} - -func (b *BufReader) UnreadByte() error { - if b.lastByte < 0 || b.r == 0 && b.w > 0 { - return bufio.ErrInvalidUnreadByte - } - // b.r > 0 || b.w == 0 - if b.r > 0 { - b.r-- - } else { - // b.r == 0 && b.w == 0 - b.w = 1 - } - b.buf[b.r] = byte(b.lastByte) - b.lastByte = -1 - b.changeAvailable(+1) - return nil -} - -// Discard skips the next n bytes, returning the number of bytes discarded. -// -// If Discard skips fewer than n bytes, it also returns an error. -// If 0 <= n <= b.Buffered(), Discard is guaranteed to succeed without -// reading from the underlying io.BufReader. -func (b *BufReader) Discard(n int) (discarded int, err error) { - if n < 0 { - return 0, bufio.ErrNegativeCount - } - if n == 0 { - return - } - remain := n - for { - skip := b.Buffered() - if skip == 0 { - b.fill() - skip = b.Buffered() - } - if skip > remain { - skip = remain - } - b.r += skip - b.changeAvailable(-skip) - remain -= skip - if remain == 0 { - return n, nil - } - if b.err != nil { - return n - remain, b.readErr() - } - } -} - -func (b *BufReader) ReadN(n int) (line []byte, err error) { - if n < 0 { - return nil, bufio.ErrNegativeCount - } - if n == 0 { - return - } - - nn := n - if b.available != -1 && nn > b.available { - nn = b.available - } - - for { - buffered := b.Buffered() - - if buffered >= nn { - line = b.buf[b.r : b.r+nn] - b.r += nn - b.changeAvailable(-nn) - if n > nn { - err = io.EOF - } - break - } - - // Pending error? - if b.err != nil { - line = b.flush() - err = b.readErr() - break - } - - // Buffer full? - if buffered >= len(b.buf) { - line = b.flush() - err = bufio.ErrBufferFull - break - } - - b.fill() // buffer is not full - } - - // Handle last byte, if any. - if i := len(line) - 1; i >= 0 { - b.lastByte = int(line[i]) - } - - return line, err -} - -func (b *BufReader) ReadFull() ([]byte, error) { - if b.available == -1 { - panic("not reached") - } - buf := make([]byte, b.available) - _, err := io.ReadFull(b, buf) - return buf, err -} - -func (b *BufReader) ReadFullTemp() ([]byte, error) { - if b.available == -1 { - panic("not reached") - } - if b.available <= len(b.buf) { - return b.ReadN(b.available) - } - return b.ReadFull() -} - -func (b *BufReader) read(buf []byte) (int, error) { - n, err := b.rd.Read(buf) - b.bytesRead += int64(n) - return n, err -} diff --git a/vendor/github.com/go-pg/pg/v10/internal/pool/reader_bytes.go b/vendor/github.com/go-pg/pg/v10/internal/pool/reader_bytes.go deleted file mode 100644 index 93646b1da..000000000 --- a/vendor/github.com/go-pg/pg/v10/internal/pool/reader_bytes.go +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package pool - -import ( - "bytes" - "errors" - "io" -) - -type BytesReader struct { - s []byte - i int -} - -func NewBytesReader(b []byte) *BytesReader { - return &BytesReader{ - s: b, - } -} - -func (r *BytesReader) Reset(b []byte) { - r.s = b - r.i = 0 -} - -func (r *BytesReader) Buffered() int { - return len(r.s) - r.i -} - -func (r *BytesReader) Bytes() []byte { - return r.s[r.i:] -} - -func (r *BytesReader) Read(b []byte) (n int, err error) { - if r.i >= len(r.s) { - return 0, io.EOF - } - n = copy(b, r.s[r.i:]) - r.i += n - return -} - -func (r *BytesReader) ReadByte() (byte, error) { - if r.i >= len(r.s) { - return 0, io.EOF - } - b := r.s[r.i] - r.i++ - return b, nil -} - -func (r *BytesReader) UnreadByte() error { - if r.i <= 0 { - return errors.New("UnreadByte: at beginning of slice") - } - r.i-- - return nil -} - -func (r *BytesReader) ReadSlice(delim byte) ([]byte, error) { - if i := bytes.IndexByte(r.s[r.i:], delim); i >= 0 { - i++ - line := r.s[r.i : r.i+i] - r.i += i - return line, nil - } - - line := r.s[r.i:] - r.i = len(r.s) - return line, io.EOF -} - -func (r *BytesReader) ReadBytes(fn func(byte) bool) ([]byte, error) { - for i, c := range r.s[r.i:] { - if !fn(c) { - i++ - line := r.s[r.i : r.i+i] - r.i += i - return line, nil - } - } - - line := r.s[r.i:] - r.i = len(r.s) - return line, io.EOF -} - -func (r *BytesReader) Discard(n int) (int, error) { - b, err := r.ReadN(n) - return len(b), err -} - -func (r *BytesReader) ReadN(n int) ([]byte, error) { - nn := n - if nn > len(r.s) { - nn = len(r.s) - } - - b := r.s[r.i : r.i+nn] - r.i += nn - if n > nn { - return b, io.EOF - } - return b, nil -} - -func (r *BytesReader) ReadFull() ([]byte, error) { - b := make([]byte, len(r.s)-r.i) - copy(b, r.s[r.i:]) - r.i = len(r.s) - return b, nil -} - -func (r *BytesReader) ReadFullTemp() ([]byte, error) { - b := r.s[r.i:] - r.i = len(r.s) - return b, nil -} diff --git a/vendor/github.com/go-pg/pg/v10/internal/pool/write_buffer.go b/vendor/github.com/go-pg/pg/v10/internal/pool/write_buffer.go deleted file mode 100644 index 6981d3f4c..000000000 --- a/vendor/github.com/go-pg/pg/v10/internal/pool/write_buffer.go +++ /dev/null @@ -1,114 +0,0 @@ -package pool - -import ( - "encoding/binary" - "io" - "sync" -) - -const defaultBufSize = 65 << 10 // 65kb - -var wbPool = sync.Pool{ - New: func() interface{} { - return NewWriteBuffer() - }, -} - -func GetWriteBuffer() *WriteBuffer { - wb := wbPool.Get().(*WriteBuffer) - return wb -} - -func PutWriteBuffer(wb *WriteBuffer) { - wb.Reset() - wbPool.Put(wb) -} - -type WriteBuffer struct { - Bytes []byte - - msgStart int - paramStart int -} - -func NewWriteBuffer() *WriteBuffer { - return &WriteBuffer{ - Bytes: make([]byte, 0, defaultBufSize), - } -} - -func (buf *WriteBuffer) Reset() { - buf.Bytes = buf.Bytes[:0] -} - -func (buf *WriteBuffer) StartMessage(c byte) { - if c == 0 { - buf.msgStart = len(buf.Bytes) - buf.Bytes = append(buf.Bytes, 0, 0, 0, 0) - } else { - buf.msgStart = len(buf.Bytes) + 1 - buf.Bytes = append(buf.Bytes, c, 0, 0, 0, 0) - } -} - -func (buf *WriteBuffer) FinishMessage() { - binary.BigEndian.PutUint32( - buf.Bytes[buf.msgStart:], uint32(len(buf.Bytes)-buf.msgStart)) -} - -func (buf *WriteBuffer) Query() []byte { - return buf.Bytes[buf.msgStart+4 : len(buf.Bytes)-1] -} - -func (buf *WriteBuffer) StartParam() { - buf.paramStart = len(buf.Bytes) - buf.Bytes = append(buf.Bytes, 0, 0, 0, 0) -} - -func (buf *WriteBuffer) FinishParam() { - binary.BigEndian.PutUint32( - buf.Bytes[buf.paramStart:], uint32(len(buf.Bytes)-buf.paramStart-4)) -} - -var nullParamLength = int32(-1) - -func (buf *WriteBuffer) FinishNullParam() { - binary.BigEndian.PutUint32( - buf.Bytes[buf.paramStart:], uint32(nullParamLength)) -} - -func (buf *WriteBuffer) Write(b []byte) (int, error) { - buf.Bytes = append(buf.Bytes, b...) - return len(b), nil -} - -func (buf *WriteBuffer) WriteInt16(num int16) { - buf.Bytes = append(buf.Bytes, 0, 0) - binary.BigEndian.PutUint16(buf.Bytes[len(buf.Bytes)-2:], uint16(num)) -} - -func (buf *WriteBuffer) WriteInt32(num int32) { - buf.Bytes = append(buf.Bytes, 0, 0, 0, 0) - binary.BigEndian.PutUint32(buf.Bytes[len(buf.Bytes)-4:], uint32(num)) -} - -func (buf *WriteBuffer) WriteString(s string) { - buf.Bytes = append(buf.Bytes, s...) - buf.Bytes = append(buf.Bytes, 0) -} - -func (buf *WriteBuffer) WriteBytes(b []byte) { - buf.Bytes = append(buf.Bytes, b...) - buf.Bytes = append(buf.Bytes, 0) -} - -func (buf *WriteBuffer) WriteByte(c byte) error { - buf.Bytes = append(buf.Bytes, c) - return nil -} - -func (buf *WriteBuffer) ReadFrom(r io.Reader) (int64, error) { - n, err := r.Read(buf.Bytes[len(buf.Bytes):cap(buf.Bytes)]) - buf.Bytes = buf.Bytes[:len(buf.Bytes)+n] - return int64(n), err -} diff --git a/vendor/github.com/go-pg/pg/v10/internal/strconv.go b/vendor/github.com/go-pg/pg/v10/internal/strconv.go deleted file mode 100644 index 9e42ffb03..000000000 --- a/vendor/github.com/go-pg/pg/v10/internal/strconv.go +++ /dev/null @@ -1,19 +0,0 @@ -package internal - -import "strconv" - -func Atoi(b []byte) (int, error) { - return strconv.Atoi(BytesToString(b)) -} - -func ParseInt(b []byte, base int, bitSize int) (int64, error) { - return strconv.ParseInt(BytesToString(b), base, bitSize) -} - -func ParseUint(b []byte, base int, bitSize int) (uint64, error) { - return strconv.ParseUint(BytesToString(b), base, bitSize) -} - -func ParseFloat(b []byte, bitSize int) (float64, error) { - return strconv.ParseFloat(BytesToString(b), bitSize) -} diff --git a/vendor/github.com/go-pg/pg/v10/internal/unsafe.go b/vendor/github.com/go-pg/pg/v10/internal/unsafe.go deleted file mode 100644 index f8bc18d91..000000000 --- a/vendor/github.com/go-pg/pg/v10/internal/unsafe.go +++ /dev/null @@ -1,22 +0,0 @@ -// +build !appengine - -package internal - -import ( - "unsafe" -) - -// BytesToString converts byte slice to string. -func BytesToString(b []byte) string { - return *(*string)(unsafe.Pointer(&b)) -} - -// StringToBytes converts string to byte slice. -func StringToBytes(s string) []byte { - return *(*[]byte)(unsafe.Pointer( - &struct { - string - Cap int - }{s, len(s)}, - )) -} diff --git a/vendor/github.com/go-pg/pg/v10/listener.go b/vendor/github.com/go-pg/pg/v10/listener.go deleted file mode 100644 index d37be08d4..000000000 --- a/vendor/github.com/go-pg/pg/v10/listener.go +++ /dev/null @@ -1,414 +0,0 @@ -package pg - -import ( - "context" - "errors" - "fmt" - "strings" - "sync" - "time" - - "github.com/go-pg/pg/v10/internal" - "github.com/go-pg/pg/v10/internal/pool" - "github.com/go-pg/pg/v10/types" -) - -const gopgChannel = "gopg:ping" - -var ( - errListenerClosed = errors.New("pg: listener is closed") - errPingTimeout = errors.New("pg: ping timeout") -) - -// Notification which is received with LISTEN command. -type Notification struct { - Channel string - Payload string -} - -// Listener listens for notifications sent with NOTIFY command. -// It's NOT safe for concurrent use by multiple goroutines -// except the Channel API. -type Listener struct { - db *DB - - channels []string - - mu sync.Mutex - cn *pool.Conn - exit chan struct{} - closed bool - - chOnce sync.Once - ch chan Notification - pingCh chan struct{} -} - -func (ln *Listener) String() string { - ln.mu.Lock() - defer ln.mu.Unlock() - - return fmt.Sprintf("Listener(%s)", strings.Join(ln.channels, ", ")) -} - -func (ln *Listener) init() { - ln.exit = make(chan struct{}) -} - -func (ln *Listener) connWithLock(ctx context.Context) (*pool.Conn, error) { - ln.mu.Lock() - cn, err := ln.conn(ctx) - ln.mu.Unlock() - - switch err { - case nil: - return cn, nil - case errListenerClosed: - return nil, err - case pool.ErrClosed: - _ = ln.Close() - return nil, errListenerClosed - default: - internal.Logger.Printf(ctx, "pg: Listen failed: %s", err) - return nil, err - } -} - -func (ln *Listener) conn(ctx context.Context) (*pool.Conn, error) { - if ln.closed { - return nil, errListenerClosed - } - - if ln.cn != nil { - return ln.cn, nil - } - - cn, err := ln.db.pool.NewConn(ctx) - if err != nil { - return nil, err - } - - if err := ln.db.initConn(ctx, cn); err != nil { - _ = ln.db.pool.CloseConn(cn) - return nil, err - } - - cn.LockReader() - - if len(ln.channels) > 0 { - err := ln.listen(ctx, cn, ln.channels...) - if err != nil { - _ = ln.db.pool.CloseConn(cn) - return nil, err - } - } - - ln.cn = cn - return cn, nil -} - -func (ln *Listener) releaseConn(ctx context.Context, cn *pool.Conn, err error, allowTimeout bool) { - ln.mu.Lock() - if ln.cn == cn { - if isBadConn(err, allowTimeout) { - ln.reconnect(ctx, err) - } - } - ln.mu.Unlock() -} - -func (ln *Listener) reconnect(ctx context.Context, reason error) { - _ = ln.closeTheCn(reason) - _, _ = ln.conn(ctx) -} - -func (ln *Listener) closeTheCn(reason error) error { - if ln.cn == nil { - return nil - } - if !ln.closed { - internal.Logger.Printf(ln.db.ctx, "pg: discarding bad listener connection: %s", reason) - } - - err := ln.db.pool.CloseConn(ln.cn) - ln.cn = nil - return err -} - -// Close closes the listener, releasing any open resources. -func (ln *Listener) Close() error { - ln.mu.Lock() - defer ln.mu.Unlock() - - if ln.closed { - return errListenerClosed - } - ln.closed = true - close(ln.exit) - - return ln.closeTheCn(errListenerClosed) -} - -// Listen starts listening for notifications on channels. -func (ln *Listener) Listen(ctx context.Context, channels ...string) error { - // Always append channels so DB.Listen works correctly. - ln.mu.Lock() - ln.channels = appendIfNotExists(ln.channels, channels...) - ln.mu.Unlock() - - cn, err := ln.connWithLock(ctx) - if err != nil { - return err - } - - if err := ln.listen(ctx, cn, channels...); err != nil { - ln.releaseConn(ctx, cn, err, false) - return err - } - - return nil -} - -func (ln *Listener) listen(ctx context.Context, cn *pool.Conn, channels ...string) error { - err := cn.WithWriter(ctx, ln.db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { - for _, channel := range channels { - if err := writeQueryMsg(wb, ln.db.fmter, "LISTEN ?", pgChan(channel)); err != nil { - return err - } - } - return nil - }) - return err -} - -// Unlisten stops listening for notifications on channels. -func (ln *Listener) Unlisten(ctx context.Context, channels ...string) error { - ln.mu.Lock() - ln.channels = removeIfExists(ln.channels, channels...) - ln.mu.Unlock() - - cn, err := ln.conn(ctx) - if err != nil { - return err - } - - if err := ln.unlisten(ctx, cn, channels...); err != nil { - ln.releaseConn(ctx, cn, err, false) - return err - } - - return nil -} - -func (ln *Listener) unlisten(ctx context.Context, cn *pool.Conn, channels ...string) error { - err := cn.WithWriter(ctx, ln.db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { - for _, channel := range channels { - if err := writeQueryMsg(wb, ln.db.fmter, "UNLISTEN ?", pgChan(channel)); err != nil { - return err - } - } - return nil - }) - return err -} - -// Receive indefinitely waits for a notification. This is low-level API -// and in most cases Channel should be used instead. -func (ln *Listener) Receive(ctx context.Context) (channel string, payload string, err error) { - return ln.ReceiveTimeout(ctx, 0) -} - -// ReceiveTimeout waits for a notification until timeout is reached. -// This is low-level API and in most cases Channel should be used instead. -func (ln *Listener) ReceiveTimeout( - ctx context.Context, timeout time.Duration, -) (channel, payload string, err error) { - cn, err := ln.connWithLock(ctx) - if err != nil { - return "", "", err - } - - err = cn.WithReader(ctx, timeout, func(rd *pool.ReaderContext) error { - channel, payload, err = readNotification(rd) - return err - }) - if err != nil { - ln.releaseConn(ctx, cn, err, timeout > 0) - return "", "", err - } - - return channel, payload, nil -} - -// Channel returns a channel for concurrently receiving notifications. -// It periodically sends Ping notification to test connection health. -// -// The channel is closed with Listener. Receive* APIs can not be used -// after channel is created. -func (ln *Listener) Channel() <-chan Notification { - return ln.channel(100) -} - -// ChannelSize is like Channel, but creates a Go channel -// with specified buffer size. -func (ln *Listener) ChannelSize(size int) <-chan Notification { - return ln.channel(size) -} - -func (ln *Listener) channel(size int) <-chan Notification { - ln.chOnce.Do(func() { - ln.initChannel(size) - }) - if cap(ln.ch) != size { - err := fmt.Errorf("pg: Listener.Channel is called with different buffer size") - panic(err) - } - return ln.ch -} - -func (ln *Listener) initChannel(size int) { - const pingTimeout = time.Second - const chanSendTimeout = time.Minute - - ctx := ln.db.ctx - _ = ln.Listen(ctx, gopgChannel) - - ln.ch = make(chan Notification, size) - ln.pingCh = make(chan struct{}, 1) - - go func() { - timer := time.NewTimer(time.Minute) - timer.Stop() - - var errCount int - for { - channel, payload, err := ln.Receive(ctx) - if err != nil { - if err == errListenerClosed { - close(ln.ch) - return - } - - if errCount > 0 { - time.Sleep(500 * time.Millisecond) - } - errCount++ - - continue - } - - errCount = 0 - - // Any notification is as good as a ping. - select { - case ln.pingCh <- struct{}{}: - default: - } - - switch channel { - case gopgChannel: - // ignore - default: - timer.Reset(chanSendTimeout) - select { - case ln.ch <- Notification{channel, payload}: - if !timer.Stop() { - <-timer.C - } - case <-timer.C: - internal.Logger.Printf( - ctx, - "pg: %s channel is full for %s (notification is dropped)", - ln, - chanSendTimeout, - ) - } - } - } - }() - - go func() { - timer := time.NewTimer(time.Minute) - timer.Stop() - - healthy := true - for { - timer.Reset(pingTimeout) - select { - case <-ln.pingCh: - healthy = true - if !timer.Stop() { - <-timer.C - } - case <-timer.C: - pingErr := ln.ping() - if healthy { - healthy = false - } else { - if pingErr == nil { - pingErr = errPingTimeout - } - ln.mu.Lock() - ln.reconnect(ctx, pingErr) - ln.mu.Unlock() - } - case <-ln.exit: - return - } - } - }() -} - -func (ln *Listener) ping() error { - _, err := ln.db.Exec("NOTIFY ?", pgChan(gopgChannel)) - return err -} - -func appendIfNotExists(ss []string, es ...string) []string { -loop: - for _, e := range es { - for _, s := range ss { - if s == e { - continue loop - } - } - ss = append(ss, e) - } - return ss -} - -func removeIfExists(ss []string, es ...string) []string { - for _, e := range es { - for i, s := range ss { - if s == e { - last := len(ss) - 1 - ss[i] = ss[last] - ss = ss[:last] - break - } - } - } - return ss -} - -type pgChan string - -var _ types.ValueAppender = pgChan("") - -func (ch pgChan) AppendValue(b []byte, quote int) ([]byte, error) { - if quote == 0 { - return append(b, ch...), nil - } - - b = append(b, '"') - for _, c := range []byte(ch) { - if c == '"' { - b = append(b, '"', '"') - } else { - b = append(b, c) - } - } - b = append(b, '"') - - return b, nil -} diff --git a/vendor/github.com/go-pg/pg/v10/messages.go b/vendor/github.com/go-pg/pg/v10/messages.go deleted file mode 100644 index 7fb84ba0d..000000000 --- a/vendor/github.com/go-pg/pg/v10/messages.go +++ /dev/null @@ -1,1390 +0,0 @@ -package pg - -import ( - "bufio" - "context" - "crypto/md5" //nolint - "crypto/tls" - "encoding/binary" - "encoding/hex" - "errors" - "fmt" - "io" - "strings" - - "mellium.im/sasl" - - "github.com/go-pg/pg/v10/internal" - "github.com/go-pg/pg/v10/internal/pool" - "github.com/go-pg/pg/v10/orm" - "github.com/go-pg/pg/v10/types" -) - -// https://www.postgresql.org/docs/current/protocol-message-formats.html -const ( - commandCompleteMsg = 'C' - errorResponseMsg = 'E' - noticeResponseMsg = 'N' - parameterStatusMsg = 'S' - authenticationOKMsg = 'R' - backendKeyDataMsg = 'K' - noDataMsg = 'n' - passwordMessageMsg = 'p' - terminateMsg = 'X' - - saslInitialResponseMsg = 'p' - authenticationSASLContinueMsg = 'R' - saslResponseMsg = 'p' - authenticationSASLFinalMsg = 'R' - - authenticationOK = 0 - authenticationCleartextPassword = 3 - authenticationMD5Password = 5 - authenticationSASL = 10 - - notificationResponseMsg = 'A' - - describeMsg = 'D' - parameterDescriptionMsg = 't' - - queryMsg = 'Q' - readyForQueryMsg = 'Z' - emptyQueryResponseMsg = 'I' - rowDescriptionMsg = 'T' - dataRowMsg = 'D' - - parseMsg = 'P' - parseCompleteMsg = '1' - - bindMsg = 'B' - bindCompleteMsg = '2' - - executeMsg = 'E' - - syncMsg = 'S' - flushMsg = 'H' - - closeMsg = 'C' - closeCompleteMsg = '3' - - copyInResponseMsg = 'G' - copyOutResponseMsg = 'H' - copyDataMsg = 'd' - copyDoneMsg = 'c' -) - -var errEmptyQuery = internal.Errorf("pg: query is empty") - -func (db *baseDB) startup( - c context.Context, cn *pool.Conn, user, password, database, appName string, -) error { - err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { - writeStartupMsg(wb, user, database, appName) - return nil - }) - if err != nil { - return err - } - - return cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { - for { - typ, msgLen, err := readMessageType(rd) - if err != nil { - return err - } - - switch typ { - case backendKeyDataMsg: - processID, err := readInt32(rd) - if err != nil { - return err - } - secretKey, err := readInt32(rd) - if err != nil { - return err - } - cn.ProcessID = processID - cn.SecretKey = secretKey - case parameterStatusMsg: - if err := logParameterStatus(rd, msgLen); err != nil { - return err - } - case authenticationOKMsg: - err := db.auth(c, cn, rd, user, password) - if err != nil { - return err - } - case readyForQueryMsg: - _, err := rd.ReadN(msgLen) - return err - case noticeResponseMsg: - // If we encounter a notice message from the server then we want to try to log it as it might be - // important for the client. If something goes wrong with this we want to fail. At the time of writing - // this the client will fail just encountering a notice during startup. So failing if a bad notice is - // sent is probably better than not failing, especially if we can try to log at least some data from the - // notice. - if err := db.logStartupNotice(rd); err != nil { - return err - } - case errorResponseMsg: - e, err := readError(rd) - if err != nil { - return err - } - return e - default: - return fmt.Errorf("pg: unknown startup message response: %q", typ) - } - } - }) -} - -func (db *baseDB) enableSSL(c context.Context, cn *pool.Conn, tlsConf *tls.Config) error { - err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { - writeSSLMsg(wb) - return nil - }) - if err != nil { - return err - } - - err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { - c, err := rd.ReadByte() - if err != nil { - return err - } - if c != 'S' { - return errors.New("pg: SSL is not enabled on the server") - } - return nil - }) - if err != nil { - return err - } - - cn.SetNetConn(tls.Client(cn.NetConn(), tlsConf)) - return nil -} - -func (db *baseDB) auth( - c context.Context, cn *pool.Conn, rd *pool.ReaderContext, user, password string, -) error { - num, err := readInt32(rd) - if err != nil { - return err - } - - switch num { - case authenticationOK: - return nil - case authenticationCleartextPassword: - return db.authCleartext(c, cn, rd, password) - case authenticationMD5Password: - return db.authMD5(c, cn, rd, user, password) - case authenticationSASL: - return db.authSASL(c, cn, rd, user, password) - default: - return fmt.Errorf("pg: unknown authentication message response: %q", num) - } -} - -// logStartupNotice will handle notice messages during the startup process. It will parse them and log them for the -// client. Notices are not common and only happen if there is something the client should be aware of. So logging should -// not be a problem. -// Notice messages can be seen in startup: https://www.postgresql.org/docs/13/protocol-flow.html -// Information on the notice message format: https://www.postgresql.org/docs/13/protocol-message-formats.html -// Note: This is true for earlier versions of PostgreSQL as well, I've just included the latest versions of the docs. -func (db *baseDB) logStartupNotice( - rd *pool.ReaderContext, -) error { - message := make([]string, 0) - // Notice messages are null byte delimited key-value pairs. Where the keys are one byte. - for { - // Read the key byte. - fieldType, err := rd.ReadByte() - if err != nil { - return err - } - - // If they key byte (the type of field this data is) is 0 then that means we have reached the end of the notice. - // We can break our loop here and throw our message data into the logger. - if fieldType == 0 { - break - } - - // Read until the next null byte to get the data for this field. This does include the null byte at the end of - // fieldValue so we will trim it off down below. - fieldValue, err := readString(rd) - if err != nil { - return err - } - - // Just throw the field type as a string and its value into an array. - // Field types can be seen here: https://www.postgresql.org/docs/13/protocol-error-fields.html - // TODO This is a rare occurrence as is, would it be worth adding something to indicate what the field names - // are? Or is PostgreSQL documentation enough for a user at this point? - message = append(message, fmt.Sprintf("%s: %s", string(fieldType), fieldValue)) - } - - // Tell the client what PostgreSQL told us. Warning because its probably something the client should at the very - // least adjust. - internal.Warn.Printf("notice during startup: %s", strings.Join(message, ", ")) - - return nil -} - -func (db *baseDB) authCleartext( - c context.Context, cn *pool.Conn, rd *pool.ReaderContext, password string, -) error { - err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { - writePasswordMsg(wb, password) - return nil - }) - if err != nil { - return err - } - return readAuthOK(rd) -} - -func (db *baseDB) authMD5( - c context.Context, cn *pool.Conn, rd *pool.ReaderContext, user, password string, -) error { - b, err := rd.ReadN(4) - if err != nil { - return err - } - - secret := "md5" + md5s(md5s(password+user)+string(b)) - err = cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { - writePasswordMsg(wb, secret) - return nil - }) - if err != nil { - return err - } - - return readAuthOK(rd) -} - -func readAuthOK(rd *pool.ReaderContext) error { - c, _, err := readMessageType(rd) - if err != nil { - return err - } - - switch c { - case authenticationOKMsg: - c0, err := readInt32(rd) - if err != nil { - return err - } - if c0 != 0 { - return fmt.Errorf("pg: unexpected authentication code: %q", c0) - } - return nil - case errorResponseMsg: - e, err := readError(rd) - if err != nil { - return err - } - return e - default: - return fmt.Errorf("pg: unknown password message response: %q", c) - } -} - -func (db *baseDB) authSASL( - c context.Context, cn *pool.Conn, rd *pool.ReaderContext, user, password string, -) error { - s, err := readString(rd) - if err != nil { - return err - } - if s != "SCRAM-SHA-256" { - return fmt.Errorf("pg: SASL: got %q, wanted %q", s, "SCRAM-SHA-256") - } - - c0, err := rd.ReadByte() - if err != nil { - return err - } - if c0 != 0 { - return fmt.Errorf("pg: SASL: got %q, wanted %q", c0, 0) - } - - creds := sasl.Credentials(func() (Username, Password, Identity []byte) { - return []byte(user), []byte(password), nil - }) - client := sasl.NewClient(sasl.ScramSha256, creds) - - _, resp, err := client.Step(nil) - if err != nil { - return err - } - - err = cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { - wb.StartMessage(saslInitialResponseMsg) - wb.WriteString("SCRAM-SHA-256") - wb.WriteInt32(int32(len(resp))) - _, err := wb.Write(resp) - if err != nil { - return err - } - wb.FinishMessage() - return nil - }) - if err != nil { - return err - } - - typ, n, err := readMessageType(rd) - if err != nil { - return err - } - - switch typ { - case authenticationSASLContinueMsg: - c11, err := readInt32(rd) - if err != nil { - return err - } - if c11 != 11 { - return fmt.Errorf("pg: SASL: got %q, wanted %q", typ, 11) - } - - b, err := rd.ReadN(n - 4) - if err != nil { - return err - } - - _, resp, err = client.Step(b) - if err != nil { - return err - } - - err = cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { - wb.StartMessage(saslResponseMsg) - _, err := wb.Write(resp) - if err != nil { - return err - } - wb.FinishMessage() - return nil - }) - if err != nil { - return err - } - - return readAuthSASLFinal(rd, client) - case errorResponseMsg: - e, err := readError(rd) - if err != nil { - return err - } - return e - default: - return fmt.Errorf( - "pg: SASL: got %q, wanted %q", typ, authenticationSASLContinueMsg) - } -} - -func readAuthSASLFinal(rd *pool.ReaderContext, client *sasl.Negotiator) error { - c, n, err := readMessageType(rd) - if err != nil { - return err - } - - switch c { - case authenticationSASLFinalMsg: - c12, err := readInt32(rd) - if err != nil { - return err - } - if c12 != 12 { - return fmt.Errorf("pg: SASL: got %q, wanted %q", c, 12) - } - - b, err := rd.ReadN(n - 4) - if err != nil { - return err - } - - _, _, err = client.Step(b) - if err != nil { - return err - } - - if client.State() != sasl.ValidServerResponse { - return fmt.Errorf("pg: SASL: state=%q, wanted %q", - client.State(), sasl.ValidServerResponse) - } - case errorResponseMsg: - e, err := readError(rd) - if err != nil { - return err - } - return e - default: - return fmt.Errorf( - "pg: SASL: got %q, wanted %q", c, authenticationSASLFinalMsg) - } - - return readAuthOK(rd) -} - -func md5s(s string) string { - //nolint - h := md5.Sum([]byte(s)) - return hex.EncodeToString(h[:]) -} - -func writeStartupMsg(buf *pool.WriteBuffer, user, database, appName string) { - buf.StartMessage(0) - buf.WriteInt32(196608) - buf.WriteString("user") - buf.WriteString(user) - buf.WriteString("database") - buf.WriteString(database) - if appName != "" { - buf.WriteString("application_name") - buf.WriteString(appName) - } - buf.WriteString("") - buf.FinishMessage() -} - -func writeSSLMsg(buf *pool.WriteBuffer) { - buf.StartMessage(0) - buf.WriteInt32(80877103) - buf.FinishMessage() -} - -func writePasswordMsg(buf *pool.WriteBuffer, password string) { - buf.StartMessage(passwordMessageMsg) - buf.WriteString(password) - buf.FinishMessage() -} - -func writeFlushMsg(buf *pool.WriteBuffer) { - buf.StartMessage(flushMsg) - buf.FinishMessage() -} - -func writeCancelRequestMsg(buf *pool.WriteBuffer, processID, secretKey int32) { - buf.StartMessage(0) - buf.WriteInt32(80877102) - buf.WriteInt32(processID) - buf.WriteInt32(secretKey) - buf.FinishMessage() -} - -func writeQueryMsg( - buf *pool.WriteBuffer, - fmter orm.QueryFormatter, - query interface{}, - params ...interface{}, -) error { - buf.StartMessage(queryMsg) - bytes, err := appendQuery(fmter, buf.Bytes, query, params...) - if err != nil { - return err - } - buf.Bytes = bytes - err = buf.WriteByte(0x0) - if err != nil { - return err - } - buf.FinishMessage() - return nil -} - -func appendQuery(fmter orm.QueryFormatter, dst []byte, query interface{}, params ...interface{}) ([]byte, error) { - switch query := query.(type) { - case orm.QueryAppender: - if v, ok := fmter.(*orm.Formatter); ok { - fmter = v.WithModel(query) - } - return query.AppendQuery(fmter, dst) - case string: - if len(params) > 0 { - model, ok := params[len(params)-1].(orm.TableModel) - if ok { - if v, ok := fmter.(*orm.Formatter); ok { - fmter = v.WithTableModel(model) - params = params[:len(params)-1] - } - } - } - return fmter.FormatQuery(dst, query, params...), nil - default: - return nil, fmt.Errorf("pg: can't append %T", query) - } -} - -func writeSyncMsg(buf *pool.WriteBuffer) { - buf.StartMessage(syncMsg) - buf.FinishMessage() -} - -func writeParseDescribeSyncMsg(buf *pool.WriteBuffer, name, q string) { - buf.StartMessage(parseMsg) - buf.WriteString(name) - buf.WriteString(q) - buf.WriteInt16(0) - buf.FinishMessage() - - buf.StartMessage(describeMsg) - buf.WriteByte('S') //nolint - buf.WriteString(name) - buf.FinishMessage() - - writeSyncMsg(buf) -} - -func readParseDescribeSync(rd *pool.ReaderContext) ([]types.ColumnInfo, error) { - var columns []types.ColumnInfo - var firstErr error - for { - c, msgLen, err := readMessageType(rd) - if err != nil { - return nil, err - } - switch c { - case parseCompleteMsg: - _, err = rd.ReadN(msgLen) - if err != nil { - return nil, err - } - case rowDescriptionMsg: // Response to the DESCRIBE message. - columns, err = readRowDescription(rd, pool.NewColumnAlloc()) - if err != nil { - return nil, err - } - case parameterDescriptionMsg: // Response to the DESCRIBE message. - _, err := rd.ReadN(msgLen) - if err != nil { - return nil, err - } - case noDataMsg: // Response to the DESCRIBE message. - _, err := rd.ReadN(msgLen) - if err != nil { - return nil, err - } - case readyForQueryMsg: - _, err := rd.ReadN(msgLen) - if err != nil { - return nil, err - } - if firstErr != nil { - return nil, firstErr - } - return columns, err - case errorResponseMsg: - e, err := readError(rd) - if err != nil { - return nil, err - } - if firstErr == nil { - firstErr = e - } - case noticeResponseMsg: - if err := logNotice(rd, msgLen); err != nil { - return nil, err - } - case parameterStatusMsg: - if err := logParameterStatus(rd, msgLen); err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("pg: readParseDescribeSync: unexpected message %q", c) - } - } -} - -// Writes BIND, EXECUTE and SYNC messages. -func writeBindExecuteMsg(buf *pool.WriteBuffer, name string, params ...interface{}) error { - buf.StartMessage(bindMsg) - buf.WriteString("") - buf.WriteString(name) - buf.WriteInt16(0) - buf.WriteInt16(int16(len(params))) - for _, param := range params { - buf.StartParam() - bytes := types.Append(buf.Bytes, param, 0) - if bytes != nil { - buf.Bytes = bytes - buf.FinishParam() - } else { - buf.FinishNullParam() - } - } - buf.WriteInt16(0) - buf.FinishMessage() - - buf.StartMessage(executeMsg) - buf.WriteString("") - buf.WriteInt32(0) - buf.FinishMessage() - - writeSyncMsg(buf) - - return nil -} - -func writeCloseMsg(buf *pool.WriteBuffer, name string) { - buf.StartMessage(closeMsg) - buf.WriteByte('S') //nolint - buf.WriteString(name) - buf.FinishMessage() -} - -func readCloseCompleteMsg(rd *pool.ReaderContext) error { - for { - c, msgLen, err := readMessageType(rd) - if err != nil { - return err - } - switch c { - case closeCompleteMsg: - _, err := rd.ReadN(msgLen) - return err - case errorResponseMsg: - e, err := readError(rd) - if err != nil { - return err - } - return e - case noticeResponseMsg: - if err := logNotice(rd, msgLen); err != nil { - return err - } - case parameterStatusMsg: - if err := logParameterStatus(rd, msgLen); err != nil { - return err - } - default: - return fmt.Errorf("pg: readCloseCompleteMsg: unexpected message %q", c) - } - } -} - -func readSimpleQuery(rd *pool.ReaderContext) (*result, error) { - var res result - var firstErr error - for { - c, msgLen, err := readMessageType(rd) - if err != nil { - return nil, err - } - - switch c { - case commandCompleteMsg: - b, err := rd.ReadN(msgLen) - if err != nil { - return nil, err - } - if err := res.parse(b); err != nil && firstErr == nil { - firstErr = err - } - case readyForQueryMsg: - _, err := rd.ReadN(msgLen) - if err != nil { - return nil, err - } - if firstErr != nil { - return nil, firstErr - } - return &res, nil - case rowDescriptionMsg: - _, err := rd.ReadN(msgLen) - if err != nil { - return nil, err - } - case dataRowMsg: - if _, err := rd.Discard(msgLen); err != nil { - return nil, err - } - res.returned++ - case errorResponseMsg: - e, err := readError(rd) - if err != nil { - return nil, err - } - if firstErr == nil { - firstErr = e - } - case emptyQueryResponseMsg: - if firstErr == nil { - firstErr = errEmptyQuery - } - case noticeResponseMsg: - if err := logNotice(rd, msgLen); err != nil { - return nil, err - } - case parameterStatusMsg: - if err := logParameterStatus(rd, msgLen); err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("pg: readSimpleQuery: unexpected message %q", c) - } - } -} - -func readExtQuery(rd *pool.ReaderContext) (*result, error) { - var res result - var firstErr error - for { - c, msgLen, err := readMessageType(rd) - if err != nil { - return nil, err - } - - switch c { - case bindCompleteMsg: - _, err := rd.ReadN(msgLen) - if err != nil { - return nil, err - } - case dataRowMsg: - _, err := rd.ReadN(msgLen) - if err != nil { - return nil, err - } - res.returned++ - case commandCompleteMsg: // Response to the EXECUTE message. - b, err := rd.ReadN(msgLen) - if err != nil { - return nil, err - } - if err := res.parse(b); err != nil && firstErr == nil { - firstErr = err - } - case readyForQueryMsg: // Response to the SYNC message. - _, err := rd.ReadN(msgLen) - if err != nil { - return nil, err - } - if firstErr != nil { - return nil, firstErr - } - return &res, nil - case errorResponseMsg: - e, err := readError(rd) - if err != nil { - return nil, err - } - if firstErr == nil { - firstErr = e - } - case emptyQueryResponseMsg: - if firstErr == nil { - firstErr = errEmptyQuery - } - case noticeResponseMsg: - if err := logNotice(rd, msgLen); err != nil { - return nil, err - } - case parameterStatusMsg: - if err := logParameterStatus(rd, msgLen); err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("pg: readExtQuery: unexpected message %q", c) - } - } -} - -func readRowDescription( - rd *pool.ReaderContext, columnAlloc *pool.ColumnAlloc, -) ([]types.ColumnInfo, error) { - numCol, err := readInt16(rd) - if err != nil { - return nil, err - } - - for i := 0; i < int(numCol); i++ { - b, err := rd.ReadSlice(0) - if err != nil { - return nil, err - } - - col := columnAlloc.New(int16(i), b[:len(b)-1]) - - if _, err := rd.ReadN(6); err != nil { - return nil, err - } - - dataType, err := readInt32(rd) - if err != nil { - return nil, err - } - col.DataType = dataType - - if _, err := rd.ReadN(8); err != nil { - return nil, err - } - } - - return columnAlloc.Columns(), nil -} - -func readDataRow( - ctx context.Context, - rd *pool.ReaderContext, - columns []types.ColumnInfo, - scanner orm.ColumnScanner, -) error { - numCol, err := readInt16(rd) - if err != nil { - return err - } - - if h, ok := scanner.(orm.BeforeScanHook); ok { - if err := h.BeforeScan(ctx); err != nil { - return err - } - } - - var firstErr error - - for colIdx := int16(0); colIdx < numCol; colIdx++ { - n, err := readInt32(rd) - if err != nil { - return err - } - - var colRd types.Reader - if int(n) <= rd.Buffered() { - colRd = rd.BytesReader(int(n)) - } else { - rd.SetAvailable(int(n)) - colRd = rd - } - - column := columns[colIdx] - if err := scanner.ScanColumn(column, colRd, int(n)); err != nil && firstErr == nil { - firstErr = internal.Errorf(err.Error()) - } - - if rd == colRd { - if rd.Available() > 0 { - if _, err := rd.Discard(rd.Available()); err != nil && firstErr == nil { - firstErr = err - } - } - rd.SetAvailable(-1) - } - } - - if h, ok := scanner.(orm.AfterScanHook); ok { - if err := h.AfterScan(ctx); err != nil { - return err - } - } - - return firstErr -} - -func newModel(mod interface{}) (orm.Model, error) { - m, err := orm.NewModel(mod) - if err != nil { - return nil, err - } - return m, m.Init() -} - -func readSimpleQueryData( - ctx context.Context, rd *pool.ReaderContext, mod interface{}, -) (*result, error) { - var columns []types.ColumnInfo - var res result - var firstErr error - for { - c, msgLen, err := readMessageType(rd) - if err != nil { - return nil, err - } - - switch c { - case rowDescriptionMsg: - columns, err = readRowDescription(rd, rd.ColumnAlloc) - if err != nil { - return nil, err - } - - if res.model == nil { - var err error - res.model, err = newModel(mod) - if err != nil { - if firstErr == nil { - firstErr = err - } - res.model = Discard - } - } - case dataRowMsg: - scanner := res.model.NextColumnScanner() - if err := readDataRow(ctx, rd, columns, scanner); err != nil { - if firstErr == nil { - firstErr = err - } - } else if err := res.model.AddColumnScanner(scanner); err != nil { - if firstErr == nil { - firstErr = err - } - } - - res.returned++ - case commandCompleteMsg: - b, err := rd.ReadN(msgLen) - if err != nil { - return nil, err - } - if err := res.parse(b); err != nil && firstErr == nil { - firstErr = err - } - case readyForQueryMsg: - _, err := rd.ReadN(msgLen) - if err != nil { - return nil, err - } - if firstErr != nil { - return nil, firstErr - } - return &res, nil - case errorResponseMsg: - e, err := readError(rd) - if err != nil { - return nil, err - } - if firstErr == nil { - firstErr = e - } - case emptyQueryResponseMsg: - if firstErr == nil { - firstErr = errEmptyQuery - } - case noticeResponseMsg: - if err := logNotice(rd, msgLen); err != nil { - return nil, err - } - case parameterStatusMsg: - if err := logParameterStatus(rd, msgLen); err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("pg: readSimpleQueryData: unexpected message %q", c) - } - } -} - -func readExtQueryData( - ctx context.Context, rd *pool.ReaderContext, mod interface{}, columns []types.ColumnInfo, -) (*result, error) { - var res result - var firstErr error - for { - c, msgLen, err := readMessageType(rd) - if err != nil { - return nil, err - } - - switch c { - case bindCompleteMsg: - _, err := rd.ReadN(msgLen) - if err != nil { - return nil, err - } - case dataRowMsg: - if res.model == nil { - var err error - res.model, err = newModel(mod) - if err != nil { - if firstErr == nil { - firstErr = err - } - res.model = Discard - } - } - - scanner := res.model.NextColumnScanner() - if err := readDataRow(ctx, rd, columns, scanner); err != nil { - if firstErr == nil { - firstErr = err - } - } else if err := res.model.AddColumnScanner(scanner); err != nil { - if firstErr == nil { - firstErr = err - } - } - - res.returned++ - case commandCompleteMsg: // Response to the EXECUTE message. - b, err := rd.ReadN(msgLen) - if err != nil { - return nil, err - } - if err := res.parse(b); err != nil && firstErr == nil { - firstErr = err - } - case readyForQueryMsg: // Response to the SYNC message. - _, err := rd.ReadN(msgLen) - if err != nil { - return nil, err - } - if firstErr != nil { - return nil, firstErr - } - return &res, nil - case errorResponseMsg: - e, err := readError(rd) - if err != nil { - return nil, err - } - if firstErr == nil { - firstErr = e - } - case noticeResponseMsg: - if err := logNotice(rd, msgLen); err != nil { - return nil, err - } - case parameterStatusMsg: - if err := logParameterStatus(rd, msgLen); err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("pg: readExtQueryData: unexpected message %q", c) - } - } -} - -func readCopyInResponse(rd *pool.ReaderContext) error { - var firstErr error - for { - c, msgLen, err := readMessageType(rd) - if err != nil { - return err - } - - switch c { - case copyInResponseMsg: - _, err := rd.ReadN(msgLen) - return err - case errorResponseMsg: - e, err := readError(rd) - if err != nil { - return err - } - if firstErr == nil { - firstErr = e - } - case readyForQueryMsg: - _, err := rd.ReadN(msgLen) - if err != nil { - return err - } - return firstErr - case noticeResponseMsg: - if err := logNotice(rd, msgLen); err != nil { - return err - } - case parameterStatusMsg: - if err := logParameterStatus(rd, msgLen); err != nil { - return err - } - default: - return fmt.Errorf("pg: readCopyInResponse: unexpected message %q", c) - } - } -} - -func readCopyOutResponse(rd *pool.ReaderContext) error { - var firstErr error - for { - c, msgLen, err := readMessageType(rd) - if err != nil { - return err - } - - switch c { - case copyOutResponseMsg: - _, err := rd.ReadN(msgLen) - return err - case errorResponseMsg: - e, err := readError(rd) - if err != nil { - return err - } - if firstErr == nil { - firstErr = e - } - case readyForQueryMsg: - _, err := rd.ReadN(msgLen) - if err != nil { - return err - } - return firstErr - case noticeResponseMsg: - if err := logNotice(rd, msgLen); err != nil { - return err - } - case parameterStatusMsg: - if err := logParameterStatus(rd, msgLen); err != nil { - return err - } - default: - return fmt.Errorf("pg: readCopyOutResponse: unexpected message %q", c) - } - } -} - -func readCopyData(rd *pool.ReaderContext, w io.Writer) (*result, error) { - var res result - var firstErr error - for { - c, msgLen, err := readMessageType(rd) - if err != nil { - return nil, err - } - - switch c { - case copyDataMsg: - for msgLen > 0 { - b, err := rd.ReadN(msgLen) - if err != nil && err != bufio.ErrBufferFull { - return nil, err - } - - _, err = w.Write(b) - if err != nil { - return nil, err - } - - msgLen -= len(b) - } - case copyDoneMsg: - _, err := rd.ReadN(msgLen) - if err != nil { - return nil, err - } - case commandCompleteMsg: - b, err := rd.ReadN(msgLen) - if err != nil { - return nil, err - } - if err := res.parse(b); err != nil && firstErr == nil { - firstErr = err - } - case readyForQueryMsg: - _, err := rd.ReadN(msgLen) - if err != nil { - return nil, err - } - if firstErr != nil { - return nil, firstErr - } - return &res, nil - case errorResponseMsg: - e, err := readError(rd) - if err != nil { - return nil, err - } - return nil, e - case noticeResponseMsg: - if err := logNotice(rd, msgLen); err != nil { - return nil, err - } - case parameterStatusMsg: - if err := logParameterStatus(rd, msgLen); err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("pg: readCopyData: unexpected message %q", c) - } - } -} - -func writeCopyData(buf *pool.WriteBuffer, r io.Reader) error { - buf.StartMessage(copyDataMsg) - _, err := buf.ReadFrom(r) - buf.FinishMessage() - return err -} - -func writeCopyDone(buf *pool.WriteBuffer) { - buf.StartMessage(copyDoneMsg) - buf.FinishMessage() -} - -func readReadyForQuery(rd *pool.ReaderContext) (*result, error) { - var res result - var firstErr error - for { - c, msgLen, err := readMessageType(rd) - if err != nil { - return nil, err - } - - switch c { - case commandCompleteMsg: - b, err := rd.ReadN(msgLen) - if err != nil { - return nil, err - } - if err := res.parse(b); err != nil && firstErr == nil { - firstErr = err - } - case readyForQueryMsg: - _, err := rd.ReadN(msgLen) - if err != nil { - return nil, err - } - if firstErr != nil { - return nil, firstErr - } - return &res, nil - case errorResponseMsg: - e, err := readError(rd) - if err != nil { - return nil, err - } - if firstErr == nil { - firstErr = e - } - case noticeResponseMsg: - if err := logNotice(rd, msgLen); err != nil { - return nil, err - } - case parameterStatusMsg: - if err := logParameterStatus(rd, msgLen); err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("pg: readReadyForQueryOrError: unexpected message %q", c) - } - } -} - -func readNotification(rd *pool.ReaderContext) (channel, payload string, err error) { - for { - c, msgLen, err := readMessageType(rd) - if err != nil { - return "", "", err - } - - switch c { - case commandCompleteMsg: - _, err := rd.ReadN(msgLen) - if err != nil { - return "", "", err - } - case readyForQueryMsg: - _, err := rd.ReadN(msgLen) - if err != nil { - return "", "", err - } - case errorResponseMsg: - e, err := readError(rd) - if err != nil { - return "", "", err - } - return "", "", e - case noticeResponseMsg: - if err := logNotice(rd, msgLen); err != nil { - return "", "", err - } - case notificationResponseMsg: - _, err := readInt32(rd) - if err != nil { - return "", "", err - } - channel, err = readString(rd) - if err != nil { - return "", "", err - } - payload, err = readString(rd) - if err != nil { - return "", "", err - } - return channel, payload, nil - default: - return "", "", fmt.Errorf("pg: readNotification: unexpected message %q", c) - } - } -} - -var terminateMessage = []byte{terminateMsg, 0, 0, 0, 4} - -func terminateConn(cn *pool.Conn) error { - // Don't use cn.Buf because it is racy with user code. - _, err := cn.NetConn().Write(terminateMessage) - return err -} - -//------------------------------------------------------------------------------ - -func logNotice(rd *pool.ReaderContext, msgLen int) error { - _, err := rd.ReadN(msgLen) - return err -} - -func logParameterStatus(rd *pool.ReaderContext, msgLen int) error { - _, err := rd.ReadN(msgLen) - return err -} - -func readInt16(rd *pool.ReaderContext) (int16, error) { - b, err := rd.ReadN(2) - if err != nil { - return 0, err - } - return int16(binary.BigEndian.Uint16(b)), nil -} - -func readInt32(rd *pool.ReaderContext) (int32, error) { - b, err := rd.ReadN(4) - if err != nil { - return 0, err - } - return int32(binary.BigEndian.Uint32(b)), nil -} - -func readString(rd *pool.ReaderContext) (string, error) { - b, err := rd.ReadSlice(0) - if err != nil { - return "", err - } - return string(b[:len(b)-1]), nil -} - -func readError(rd *pool.ReaderContext) (error, error) { - m := make(map[byte]string) - for { - c, err := rd.ReadByte() - if err != nil { - return nil, err - } - if c == 0 { - break - } - s, err := readString(rd) - if err != nil { - return nil, err - } - m[c] = s - } - return internal.NewPGError(m), nil -} - -func readMessageType(rd *pool.ReaderContext) (byte, int, error) { - c, err := rd.ReadByte() - if err != nil { - return 0, 0, err - } - l, err := readInt32(rd) - if err != nil { - return 0, 0, err - } - return c, int(l) - 4, nil -} diff --git a/vendor/github.com/go-pg/pg/v10/options.go b/vendor/github.com/go-pg/pg/v10/options.go deleted file mode 100644 index efd634fd2..000000000 --- a/vendor/github.com/go-pg/pg/v10/options.go +++ /dev/null @@ -1,277 +0,0 @@ -package pg - -import ( - "context" - "crypto/tls" - "errors" - "fmt" - "net" - "net/url" - "os" - "runtime" - "strconv" - "strings" - "time" - - "github.com/go-pg/pg/v10/internal/pool" -) - -// Options contains database connection options. -type Options struct { - // Network type, either tcp or unix. - // Default is tcp. - Network string - // TCP host:port or Unix socket depending on Network. - Addr string - - // Dialer creates new network connection and has priority over - // Network and Addr options. - Dialer func(ctx context.Context, network, addr string) (net.Conn, error) - - // Hook that is called after new connection is established - // and user is authenticated. - OnConnect func(ctx context.Context, cn *Conn) error - - User string - Password string - Database string - - // ApplicationName is the application name. Used in logs on Pg side. - // Only available from pg-9.0. - ApplicationName string - - // TLS config for secure connections. - TLSConfig *tls.Config - - // Dial timeout for establishing new connections. - // Default is 5 seconds. - DialTimeout time.Duration - - // Timeout for socket reads. If reached, commands will fail - // with a timeout instead of blocking. - ReadTimeout time.Duration - // Timeout for socket writes. If reached, commands will fail - // with a timeout instead of blocking. - WriteTimeout time.Duration - - // Maximum number of retries before giving up. - // Default is to not retry failed queries. - MaxRetries int - // Whether to retry queries cancelled because of statement_timeout. - RetryStatementTimeout bool - // Minimum backoff between each retry. - // Default is 250 milliseconds; -1 disables backoff. - MinRetryBackoff time.Duration - // Maximum backoff between each retry. - // Default is 4 seconds; -1 disables backoff. - MaxRetryBackoff time.Duration - - // Maximum number of socket connections. - // Default is 10 connections per every CPU as reported by runtime.NumCPU. - PoolSize int - // Minimum number of idle connections which is useful when establishing - // new connection is slow. - MinIdleConns int - // Connection age at which client retires (closes) the connection. - // It is useful with proxies like PgBouncer and HAProxy. - // Default is to not close aged connections. - MaxConnAge time.Duration - // Time for which client waits for free connection if all - // connections are busy before returning an error. - // Default is 30 seconds if ReadTimeOut is not defined, otherwise, - // ReadTimeout + 1 second. - PoolTimeout time.Duration - // Amount of time after which client closes idle connections. - // Should be less than server's timeout. - // Default is 5 minutes. -1 disables idle timeout check. - IdleTimeout time.Duration - // Frequency of idle checks made by idle connections reaper. - // Default is 1 minute. -1 disables idle connections reaper, - // but idle connections are still discarded by the client - // if IdleTimeout is set. - IdleCheckFrequency time.Duration -} - -func (opt *Options) init() { - if opt.Network == "" { - opt.Network = "tcp" - } - - if opt.Addr == "" { - switch opt.Network { - case "tcp": - host := env("PGHOST", "localhost") - port := env("PGPORT", "5432") - opt.Addr = fmt.Sprintf("%s:%s", host, port) - case "unix": - opt.Addr = "/var/run/postgresql/.s.PGSQL.5432" - } - } - - if opt.DialTimeout == 0 { - opt.DialTimeout = 5 * time.Second - } - if opt.Dialer == nil { - opt.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) { - netDialer := &net.Dialer{ - Timeout: opt.DialTimeout, - KeepAlive: 5 * time.Minute, - } - return netDialer.DialContext(ctx, network, addr) - } - } - - if opt.User == "" { - opt.User = env("PGUSER", "postgres") - } - - if opt.Database == "" { - opt.Database = env("PGDATABASE", "postgres") - } - - if opt.PoolSize == 0 { - opt.PoolSize = 10 * runtime.NumCPU() - } - - if opt.PoolTimeout == 0 { - if opt.ReadTimeout != 0 { - opt.PoolTimeout = opt.ReadTimeout + time.Second - } else { - opt.PoolTimeout = 30 * time.Second - } - } - - if opt.IdleTimeout == 0 { - opt.IdleTimeout = 5 * time.Minute - } - if opt.IdleCheckFrequency == 0 { - opt.IdleCheckFrequency = time.Minute - } - - switch opt.MinRetryBackoff { - case -1: - opt.MinRetryBackoff = 0 - case 0: - opt.MinRetryBackoff = 250 * time.Millisecond - } - switch opt.MaxRetryBackoff { - case -1: - opt.MaxRetryBackoff = 0 - case 0: - opt.MaxRetryBackoff = 4 * time.Second - } -} - -func env(key, defValue string) string { - envValue := os.Getenv(key) - if envValue != "" { - return envValue - } - return defValue -} - -// ParseURL parses an URL into options that can be used to connect to PostgreSQL. -func ParseURL(sURL string) (*Options, error) { - parsedURL, err := url.Parse(sURL) - if err != nil { - return nil, err - } - - // scheme - if parsedURL.Scheme != "postgres" && parsedURL.Scheme != "postgresql" { - return nil, errors.New("pg: invalid scheme: " + parsedURL.Scheme) - } - - // host and port - options := &Options{ - Addr: parsedURL.Host, - } - if !strings.Contains(options.Addr, ":") { - options.Addr += ":5432" - } - - // username and password - if parsedURL.User != nil { - options.User = parsedURL.User.Username() - - if password, ok := parsedURL.User.Password(); ok { - options.Password = password - } - } - - if options.User == "" { - options.User = "postgres" - } - - // database - if len(strings.Trim(parsedURL.Path, "/")) > 0 { - options.Database = parsedURL.Path[1:] - } else { - return nil, errors.New("pg: database name not provided") - } - - // ssl mode - query, err := url.ParseQuery(parsedURL.RawQuery) - if err != nil { - return nil, err - } - - if sslMode, ok := query["sslmode"]; ok && len(sslMode) > 0 { - switch sslMode[0] { - case "verify-ca", "verify-full": - options.TLSConfig = &tls.Config{} - case "allow", "prefer", "require": - options.TLSConfig = &tls.Config{InsecureSkipVerify: true} //nolint - case "disable": - options.TLSConfig = nil - default: - return nil, fmt.Errorf("pg: sslmode '%v' is not supported", sslMode[0]) - } - } else { - options.TLSConfig = &tls.Config{InsecureSkipVerify: true} //nolint - } - - delete(query, "sslmode") - - if appName, ok := query["application_name"]; ok && len(appName) > 0 { - options.ApplicationName = appName[0] - } - - delete(query, "application_name") - - if connTimeout, ok := query["connect_timeout"]; ok && len(connTimeout) > 0 { - ct, err := strconv.Atoi(connTimeout[0]) - if err != nil { - return nil, fmt.Errorf("pg: cannot parse connect_timeout option as int") - } - options.DialTimeout = time.Second * time.Duration(ct) - } - - delete(query, "connect_timeout") - - if len(query) > 0 { - return nil, errors.New("pg: options other than 'sslmode', 'application_name' and 'connect_timeout' are not supported") - } - - return options, nil -} - -func (opt *Options) getDialer() func(context.Context) (net.Conn, error) { - return func(ctx context.Context) (net.Conn, error) { - return opt.Dialer(ctx, opt.Network, opt.Addr) - } -} - -func newConnPool(opt *Options) *pool.ConnPool { - return pool.NewConnPool(&pool.Options{ - Dialer: opt.getDialer(), - OnClose: terminateConn, - - PoolSize: opt.PoolSize, - MinIdleConns: opt.MinIdleConns, - MaxConnAge: opt.MaxConnAge, - PoolTimeout: opt.PoolTimeout, - IdleTimeout: opt.IdleTimeout, - IdleCheckFrequency: opt.IdleCheckFrequency, - }) -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/composite.go b/vendor/github.com/go-pg/pg/v10/orm/composite.go deleted file mode 100644 index d2e48a8b3..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/composite.go +++ /dev/null @@ -1,100 +0,0 @@ -package orm - -import ( - "fmt" - "reflect" - - "github.com/go-pg/pg/v10/internal/pool" - "github.com/go-pg/pg/v10/types" -) - -func compositeScanner(typ reflect.Type) types.ScannerFunc { - if typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } - - var table *Table - return func(v reflect.Value, rd types.Reader, n int) error { - if n == -1 { - v.Set(reflect.Zero(v.Type())) - return nil - } - - if table == nil { - table = GetTable(typ) - } - if v.Kind() == reflect.Ptr { - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) - } - v = v.Elem() - } - - p := newCompositeParser(rd) - var elemReader *pool.BytesReader - - var firstErr error - for i := 0; ; i++ { - elem, err := p.NextElem() - if err != nil { - if err == errEndOfComposite { - break - } - return err - } - - if i >= len(table.Fields) { - if firstErr == nil { - firstErr = fmt.Errorf( - "pg: %s has %d fields, but composite requires at least %d values", - table, len(table.Fields), i) - } - continue - } - - if elemReader == nil { - elemReader = pool.NewBytesReader(elem) - } else { - elemReader.Reset(elem) - } - - field := table.Fields[i] - if elem == nil { - err = field.ScanValue(v, elemReader, -1) - } else { - err = field.ScanValue(v, elemReader, len(elem)) - } - if err != nil && firstErr == nil { - firstErr = err - } - } - - return firstErr - } -} - -func compositeAppender(typ reflect.Type) types.AppenderFunc { - if typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } - - var table *Table - return func(b []byte, v reflect.Value, quote int) []byte { - if table == nil { - table = GetTable(typ) - } - if v.Kind() == reflect.Ptr { - v = v.Elem() - } - - b = append(b, "ROW("...) - for i, f := range table.Fields { - if i > 0 { - b = append(b, ',') - } - b = f.AppendValue(b, v, quote) - } - b = append(b, ')') - return b - } -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/composite_create.go b/vendor/github.com/go-pg/pg/v10/orm/composite_create.go deleted file mode 100644 index fd60a94e4..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/composite_create.go +++ /dev/null @@ -1,89 +0,0 @@ -package orm - -import ( - "strconv" -) - -type CreateCompositeOptions struct { - Varchar int // replaces PostgreSQL data type `text` with `varchar(n)` -} - -type CreateCompositeQuery struct { - q *Query - opt *CreateCompositeOptions -} - -var ( - _ QueryAppender = (*CreateCompositeQuery)(nil) - _ QueryCommand = (*CreateCompositeQuery)(nil) -) - -func NewCreateCompositeQuery(q *Query, opt *CreateCompositeOptions) *CreateCompositeQuery { - return &CreateCompositeQuery{ - q: q, - opt: opt, - } -} - -func (q *CreateCompositeQuery) String() string { - b, err := q.AppendQuery(defaultFmter, nil) - if err != nil { - panic(err) - } - return string(b) -} - -func (q *CreateCompositeQuery) Operation() QueryOp { - return CreateCompositeOp -} - -func (q *CreateCompositeQuery) Clone() QueryCommand { - return &CreateCompositeQuery{ - q: q.q.Clone(), - opt: q.opt, - } -} - -func (q *CreateCompositeQuery) Query() *Query { - return q.q -} - -func (q *CreateCompositeQuery) AppendTemplate(b []byte) ([]byte, error) { - return q.AppendQuery(dummyFormatter{}, b) -} - -func (q *CreateCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { - if q.q.stickyErr != nil { - return nil, q.q.stickyErr - } - if q.q.tableModel == nil { - return nil, errModelNil - } - - table := q.q.tableModel.Table() - - b = append(b, "CREATE TYPE "...) - b = append(b, table.Alias...) - b = append(b, " AS ("...) - - for i, field := range table.Fields { - if i > 0 { - b = append(b, ", "...) - } - - b = append(b, field.Column...) - b = append(b, " "...) - if field.UserSQLType == "" && q.opt != nil && q.opt.Varchar > 0 && - field.SQLType == "text" { - b = append(b, "varchar("...) - b = strconv.AppendInt(b, int64(q.opt.Varchar), 10) - b = append(b, ")"...) - } else { - b = append(b, field.SQLType...) - } - } - - b = append(b, ")"...) - - return b, q.q.stickyErr -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/composite_drop.go b/vendor/github.com/go-pg/pg/v10/orm/composite_drop.go deleted file mode 100644 index 2a169b07a..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/composite_drop.go +++ /dev/null @@ -1,70 +0,0 @@ -package orm - -type DropCompositeOptions struct { - IfExists bool - Cascade bool -} - -type DropCompositeQuery struct { - q *Query - opt *DropCompositeOptions -} - -var ( - _ QueryAppender = (*DropCompositeQuery)(nil) - _ QueryCommand = (*DropCompositeQuery)(nil) -) - -func NewDropCompositeQuery(q *Query, opt *DropCompositeOptions) *DropCompositeQuery { - return &DropCompositeQuery{ - q: q, - opt: opt, - } -} - -func (q *DropCompositeQuery) String() string { - b, err := q.AppendQuery(defaultFmter, nil) - if err != nil { - panic(err) - } - return string(b) -} - -func (q *DropCompositeQuery) Operation() QueryOp { - return DropCompositeOp -} - -func (q *DropCompositeQuery) Clone() QueryCommand { - return &DropCompositeQuery{ - q: q.q.Clone(), - opt: q.opt, - } -} - -func (q *DropCompositeQuery) Query() *Query { - return q.q -} - -func (q *DropCompositeQuery) AppendTemplate(b []byte) ([]byte, error) { - return q.AppendQuery(dummyFormatter{}, b) -} - -func (q *DropCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { - if q.q.stickyErr != nil { - return nil, q.q.stickyErr - } - if q.q.tableModel == nil { - return nil, errModelNil - } - - b = append(b, "DROP TYPE "...) - if q.opt != nil && q.opt.IfExists { - b = append(b, "IF EXISTS "...) - } - b = append(b, q.q.tableModel.Table().Alias...) - if q.opt != nil && q.opt.Cascade { - b = append(b, " CASCADE"...) - } - - return b, q.q.stickyErr -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/composite_parser.go b/vendor/github.com/go-pg/pg/v10/orm/composite_parser.go deleted file mode 100644 index 29e500444..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/composite_parser.go +++ /dev/null @@ -1,140 +0,0 @@ -package orm - -import ( - "bufio" - "errors" - "fmt" - "io" - - "github.com/go-pg/pg/v10/internal/parser" - "github.com/go-pg/pg/v10/types" -) - -var errEndOfComposite = errors.New("pg: end of composite") - -type compositeParser struct { - p parser.StreamingParser - - stickyErr error -} - -func newCompositeParserErr(err error) *compositeParser { - return &compositeParser{ - stickyErr: err, - } -} - -func newCompositeParser(rd types.Reader) *compositeParser { - p := parser.NewStreamingParser(rd) - err := p.SkipByte('(') - if err != nil { - return newCompositeParserErr(err) - } - return &compositeParser{ - p: p, - } -} - -func (p *compositeParser) NextElem() ([]byte, error) { - if p.stickyErr != nil { - return nil, p.stickyErr - } - - c, err := p.p.ReadByte() - if err != nil { - if err == io.EOF { - return nil, errEndOfComposite - } - return nil, err - } - - switch c { - case '"': - return p.readQuoted() - case ',': - return nil, nil - case ')': - return nil, errEndOfComposite - default: - _ = p.p.UnreadByte() - } - - var b []byte - for { - tmp, err := p.p.ReadSlice(',') - if err == nil { - if b == nil { - b = tmp - } else { - b = append(b, tmp...) - } - b = b[:len(b)-1] - break - } - b = append(b, tmp...) - if err == bufio.ErrBufferFull { - continue - } - if err == io.EOF { - if b[len(b)-1] == ')' { - b = b[:len(b)-1] - break - } - } - return nil, err - } - - if len(b) == 0 { // NULL - return nil, nil - } - return b, nil -} - -func (p *compositeParser) readQuoted() ([]byte, error) { - var b []byte - - c, err := p.p.ReadByte() - if err != nil { - return nil, err - } - - for { - next, err := p.p.ReadByte() - if err != nil { - return nil, err - } - - if c == '\\' || c == '\'' { - if next == c { - b = append(b, c) - c, err = p.p.ReadByte() - if err != nil { - return nil, err - } - } else { - b = append(b, c) - c = next - } - continue - } - - if c == '"' { - switch next { - case '"': - b = append(b, '"') - c, err = p.p.ReadByte() - if err != nil { - return nil, err - } - case ',', ')': - return b, nil - default: - return nil, fmt.Errorf("pg: got %q, wanted ',' or ')'", c) - } - continue - } - - b = append(b, c) - c = next - } -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/count_estimate.go b/vendor/github.com/go-pg/pg/v10/orm/count_estimate.go deleted file mode 100644 index bfa664a72..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/count_estimate.go +++ /dev/null @@ -1,90 +0,0 @@ -package orm - -import ( - "fmt" - - "github.com/go-pg/pg/v10/internal" -) - -// Placeholder that is replaced with count(*). -const placeholder = `'_go_pg_placeholder'` - -// https://wiki.postgresql.org/wiki/Count_estimate -//nolint -var pgCountEstimateFunc = fmt.Sprintf(` -CREATE OR REPLACE FUNCTION _go_pg_count_estimate_v2(query text, threshold int) -RETURNS int AS $$ -DECLARE - rec record; - nrows int; -BEGIN - FOR rec IN EXECUTE 'EXPLAIN ' || query LOOP - nrows := substring(rec."QUERY PLAN" FROM ' rows=(\d+)'); - EXIT WHEN nrows IS NOT NULL; - END LOOP; - - -- Return the estimation if there are too many rows. - IF nrows > threshold THEN - RETURN nrows; - END IF; - - -- Otherwise execute real count query. - query := replace(query, 'SELECT '%s'', 'SELECT count(*)'); - EXECUTE query INTO nrows; - - IF nrows IS NULL THEN - nrows := 0; - END IF; - - RETURN nrows; -END; -$$ LANGUAGE plpgsql; -`, placeholder) - -// CountEstimate uses EXPLAIN to get estimated number of rows returned the query. -// If that number is bigger than the threshold it returns the estimation. -// Otherwise it executes another query using count aggregate function and -// returns the result. -// -// Based on https://wiki.postgresql.org/wiki/Count_estimate -func (q *Query) CountEstimate(threshold int) (int, error) { - if q.stickyErr != nil { - return 0, q.stickyErr - } - - query, err := q.countSelectQuery(placeholder).AppendQuery(q.db.Formatter(), nil) - if err != nil { - return 0, err - } - - for i := 0; i < 3; i++ { - var count int - _, err = q.db.QueryOneContext( - q.ctx, - Scan(&count), - "SELECT _go_pg_count_estimate_v2(?, ?)", - string(query), threshold, - ) - if err != nil { - if pgerr, ok := err.(internal.PGError); ok && pgerr.Field('C') == "42883" { - // undefined_function - err = q.createCountEstimateFunc() - if err != nil { - pgerr, ok := err.(internal.PGError) - if !ok || !pgerr.IntegrityViolation() { - return 0, err - } - } - continue - } - } - return count, err - } - - return 0, err -} - -func (q *Query) createCountEstimateFunc() error { - _, err := q.db.ExecContext(q.ctx, pgCountEstimateFunc) - return err -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/delete.go b/vendor/github.com/go-pg/pg/v10/orm/delete.go deleted file mode 100644 index c54cd10f8..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/delete.go +++ /dev/null @@ -1,158 +0,0 @@ -package orm - -import ( - "reflect" - - "github.com/go-pg/pg/v10/types" -) - -type DeleteQuery struct { - q *Query - placeholder bool -} - -var ( - _ QueryAppender = (*DeleteQuery)(nil) - _ QueryCommand = (*DeleteQuery)(nil) -) - -func NewDeleteQuery(q *Query) *DeleteQuery { - return &DeleteQuery{ - q: q, - } -} - -func (q *DeleteQuery) String() string { - b, err := q.AppendQuery(defaultFmter, nil) - if err != nil { - panic(err) - } - return string(b) -} - -func (q *DeleteQuery) Operation() QueryOp { - return DeleteOp -} - -func (q *DeleteQuery) Clone() QueryCommand { - return &DeleteQuery{ - q: q.q.Clone(), - placeholder: q.placeholder, - } -} - -func (q *DeleteQuery) Query() *Query { - return q.q -} - -func (q *DeleteQuery) AppendTemplate(b []byte) ([]byte, error) { - cp := q.Clone().(*DeleteQuery) - cp.placeholder = true - return cp.AppendQuery(dummyFormatter{}, b) -} - -func (q *DeleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { - if q.q.stickyErr != nil { - return nil, q.q.stickyErr - } - - if len(q.q.with) > 0 { - b, err = q.q.appendWith(fmter, b) - if err != nil { - return nil, err - } - } - - b = append(b, "DELETE FROM "...) - b, err = q.q.appendFirstTableWithAlias(fmter, b) - if err != nil { - return nil, err - } - - if q.q.hasMultiTables() { - b = append(b, " USING "...) - b, err = q.q.appendOtherTables(fmter, b) - if err != nil { - return nil, err - } - } - - b = append(b, " WHERE "...) - value := q.q.tableModel.Value() - - if q.q.isSliceModelWithData() { - if len(q.q.where) > 0 { - b, err = q.q.appendWhere(fmter, b) - if err != nil { - return nil, err - } - } else { - table := q.q.tableModel.Table() - err = table.checkPKs() - if err != nil { - return nil, err - } - - b = appendColumnAndSliceValue(fmter, b, value, table.Alias, table.PKs) - } - } else { - b, err = q.q.mustAppendWhere(fmter, b) - if err != nil { - return nil, err - } - } - - if len(q.q.returning) > 0 { - b, err = q.q.appendReturning(fmter, b) - if err != nil { - return nil, err - } - } - - return b, q.q.stickyErr -} - -func appendColumnAndSliceValue( - fmter QueryFormatter, b []byte, slice reflect.Value, alias types.Safe, fields []*Field, -) []byte { - if len(fields) > 1 { - b = append(b, '(') - } - b = appendColumns(b, alias, fields) - if len(fields) > 1 { - b = append(b, ')') - } - - b = append(b, " IN ("...) - - isPlaceholder := isTemplateFormatter(fmter) - sliceLen := slice.Len() - for i := 0; i < sliceLen; i++ { - if i > 0 { - b = append(b, ", "...) - } - - el := indirect(slice.Index(i)) - - if len(fields) > 1 { - b = append(b, '(') - } - for i, f := range fields { - if i > 0 { - b = append(b, ", "...) - } - if isPlaceholder { - b = append(b, '?') - } else { - b = f.AppendValue(b, el, 1) - } - } - if len(fields) > 1 { - b = append(b, ')') - } - } - - b = append(b, ')') - - return b -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/field.go b/vendor/github.com/go-pg/pg/v10/orm/field.go deleted file mode 100644 index fe9b4abea..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/field.go +++ /dev/null @@ -1,146 +0,0 @@ -package orm - -import ( - "fmt" - "reflect" - - "github.com/go-pg/pg/v10/types" - "github.com/go-pg/zerochecker" -) - -const ( - PrimaryKeyFlag = uint8(1) << iota - ForeignKeyFlag - NotNullFlag - UseZeroFlag - UniqueFlag - ArrayFlag -) - -type Field struct { - Field reflect.StructField - Type reflect.Type - Index []int - - GoName string // struct field name, e.g. Id - SQLName string // SQL name, .e.g. id - Column types.Safe // escaped SQL name, e.g. "id" - SQLType string - UserSQLType string - Default types.Safe - OnDelete string - OnUpdate string - - flags uint8 - - append types.AppenderFunc - scan types.ScannerFunc - - isZero zerochecker.Func -} - -func indexEqual(ind1, ind2 []int) bool { - if len(ind1) != len(ind2) { - return false - } - for i, ind := range ind1 { - if ind != ind2[i] { - return false - } - } - return true -} - -func (f *Field) Clone() *Field { - cp := *f - cp.Index = cp.Index[:len(f.Index):len(f.Index)] - return &cp -} - -func (f *Field) setFlag(flag uint8) { - f.flags |= flag -} - -func (f *Field) hasFlag(flag uint8) bool { - return f.flags&flag != 0 -} - -func (f *Field) Value(strct reflect.Value) reflect.Value { - return fieldByIndexAlloc(strct, f.Index) -} - -func (f *Field) HasZeroValue(strct reflect.Value) bool { - return f.hasZeroValue(strct, f.Index) -} - -func (f *Field) hasZeroValue(v reflect.Value, index []int) bool { - for _, idx := range index { - if v.Kind() == reflect.Ptr { - if v.IsNil() { - return true - } - v = v.Elem() - } - v = v.Field(idx) - } - return f.isZero(v) -} - -func (f *Field) NullZero() bool { - return !f.hasFlag(UseZeroFlag) -} - -func (f *Field) AppendValue(b []byte, strct reflect.Value, quote int) []byte { - fv, ok := fieldByIndex(strct, f.Index) - if !ok { - return types.AppendNull(b, quote) - } - - if f.NullZero() && f.isZero(fv) { - return types.AppendNull(b, quote) - } - if f.append == nil { - panic(fmt.Errorf("pg: AppendValue(unsupported %s)", fv.Type())) - } - return f.append(b, fv, quote) -} - -func (f *Field) ScanValue(strct reflect.Value, rd types.Reader, n int) error { - if f.scan == nil { - return fmt.Errorf("pg: ScanValue(unsupported %s)", f.Type) - } - - var fv reflect.Value - if n == -1 { - var ok bool - fv, ok = fieldByIndex(strct, f.Index) - if !ok { - return nil - } - } else { - fv = fieldByIndexAlloc(strct, f.Index) - } - - return f.scan(fv, rd, n) -} - -type Method struct { - Index int - - flags int8 - - appender func([]byte, reflect.Value, int) []byte -} - -func (m *Method) Has(flag int8) bool { - return m.flags&flag != 0 -} - -func (m *Method) Value(strct reflect.Value) reflect.Value { - return strct.Method(m.Index).Call(nil)[0] -} - -func (m *Method) AppendValue(dst []byte, strct reflect.Value, quote int) []byte { - mv := m.Value(strct) - return m.appender(dst, mv, quote) -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/format.go b/vendor/github.com/go-pg/pg/v10/orm/format.go deleted file mode 100644 index 9945f6e1d..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/format.go +++ /dev/null @@ -1,333 +0,0 @@ -package orm - -import ( - "bytes" - "fmt" - "sort" - "strconv" - "strings" - - "github.com/go-pg/pg/v10/internal" - "github.com/go-pg/pg/v10/internal/parser" - "github.com/go-pg/pg/v10/types" -) - -var defaultFmter = NewFormatter() - -type queryWithSepAppender interface { - QueryAppender - AppendSep([]byte) []byte -} - -//------------------------------------------------------------------------------ - -type SafeQueryAppender struct { - query string - params []interface{} -} - -var ( - _ QueryAppender = (*SafeQueryAppender)(nil) - _ types.ValueAppender = (*SafeQueryAppender)(nil) -) - -//nolint -func SafeQuery(query string, params ...interface{}) *SafeQueryAppender { - return &SafeQueryAppender{query, params} -} - -func (q *SafeQueryAppender) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { - return fmter.FormatQuery(b, q.query, q.params...), nil -} - -func (q *SafeQueryAppender) AppendValue(b []byte, quote int) ([]byte, error) { - return q.AppendQuery(defaultFmter, b) -} - -func (q *SafeQueryAppender) Value() types.Safe { - b, err := q.AppendValue(nil, 1) - if err != nil { - return types.Safe(err.Error()) - } - return types.Safe(internal.BytesToString(b)) -} - -//------------------------------------------------------------------------------ - -type condGroupAppender struct { - sep string - cond []queryWithSepAppender -} - -var ( - _ QueryAppender = (*condAppender)(nil) - _ queryWithSepAppender = (*condAppender)(nil) -) - -func (q *condGroupAppender) AppendSep(b []byte) []byte { - return append(b, q.sep...) -} - -func (q *condGroupAppender) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { - b = append(b, '(') - for i, app := range q.cond { - if i > 0 { - b = app.AppendSep(b) - } - b, err = app.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - } - b = append(b, ')') - return b, nil -} - -//------------------------------------------------------------------------------ - -type condAppender struct { - sep string - cond string - params []interface{} -} - -var ( - _ QueryAppender = (*condAppender)(nil) - _ queryWithSepAppender = (*condAppender)(nil) -) - -func (q *condAppender) AppendSep(b []byte) []byte { - return append(b, q.sep...) -} - -func (q *condAppender) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { - b = append(b, '(') - b = fmter.FormatQuery(b, q.cond, q.params...) - b = append(b, ')') - return b, nil -} - -//------------------------------------------------------------------------------ - -type fieldAppender struct { - field string -} - -var _ QueryAppender = (*fieldAppender)(nil) - -func (a fieldAppender) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { - return types.AppendIdent(b, a.field, 1), nil -} - -//------------------------------------------------------------------------------ - -type dummyFormatter struct{} - -func (f dummyFormatter) FormatQuery(b []byte, query string, params ...interface{}) []byte { - return append(b, query...) -} - -func isTemplateFormatter(fmter QueryFormatter) bool { - _, ok := fmter.(dummyFormatter) - return ok -} - -//------------------------------------------------------------------------------ - -type QueryFormatter interface { - FormatQuery(b []byte, query string, params ...interface{}) []byte -} - -type Formatter struct { - namedParams map[string]interface{} - model TableModel -} - -var _ QueryFormatter = (*Formatter)(nil) - -func NewFormatter() *Formatter { - return new(Formatter) -} - -func (f *Formatter) String() string { - if len(f.namedParams) == 0 { - return "" - } - - keys := make([]string, len(f.namedParams)) - index := 0 - for k := range f.namedParams { - keys[index] = k - index++ - } - - sort.Strings(keys) - - ss := make([]string, len(keys)) - for i, k := range keys { - ss[i] = fmt.Sprintf("%s=%v", k, f.namedParams[k]) - } - return " " + strings.Join(ss, " ") -} - -func (f *Formatter) clone() *Formatter { - cp := NewFormatter() - - cp.model = f.model - if len(f.namedParams) > 0 { - cp.namedParams = make(map[string]interface{}, len(f.namedParams)) - } - for param, value := range f.namedParams { - cp.setParam(param, value) - } - - return cp -} - -func (f *Formatter) WithTableModel(model TableModel) *Formatter { - cp := f.clone() - cp.model = model - return cp -} - -func (f *Formatter) WithModel(model interface{}) *Formatter { - switch model := model.(type) { - case TableModel: - return f.WithTableModel(model) - case *Query: - return f.WithTableModel(model.tableModel) - case QueryCommand: - return f.WithTableModel(model.Query().tableModel) - default: - panic(fmt.Errorf("pg: unsupported model %T", model)) - } -} - -func (f *Formatter) setParam(param string, value interface{}) { - if f.namedParams == nil { - f.namedParams = make(map[string]interface{}) - } - f.namedParams[param] = value -} - -func (f *Formatter) WithParam(param string, value interface{}) *Formatter { - cp := f.clone() - cp.setParam(param, value) - return cp -} - -func (f *Formatter) Param(param string) interface{} { - return f.namedParams[param] -} - -func (f *Formatter) hasParams() bool { - return len(f.namedParams) > 0 || f.model != nil -} - -func (f *Formatter) FormatQueryBytes(dst, query []byte, params ...interface{}) []byte { - if (params == nil && !f.hasParams()) || bytes.IndexByte(query, '?') == -1 { - return append(dst, query...) - } - return f.append(dst, parser.New(query), params) -} - -func (f *Formatter) FormatQuery(dst []byte, query string, params ...interface{}) []byte { - if (params == nil && !f.hasParams()) || strings.IndexByte(query, '?') == -1 { - return append(dst, query...) - } - return f.append(dst, parser.NewString(query), params) -} - -func (f *Formatter) append(dst []byte, p *parser.Parser, params []interface{}) []byte { - var paramsIndex int - var namedParamsOnce bool - var tableParams *tableParams - - for p.Valid() { - b, ok := p.ReadSep('?') - if !ok { - dst = append(dst, b...) - continue - } - if len(b) > 0 && b[len(b)-1] == '\\' { - dst = append(dst, b[:len(b)-1]...) - dst = append(dst, '?') - continue - } - dst = append(dst, b...) - - id, numeric := p.ReadIdentifier() - if id != "" { - if numeric { - idx, err := strconv.Atoi(id) - if err != nil { - goto restore_param - } - - if idx >= len(params) { - goto restore_param - } - - dst = f.appendParam(dst, params[idx]) - continue - } - - if f.namedParams != nil { - param, paramOK := f.namedParams[id] - if paramOK { - dst = f.appendParam(dst, param) - continue - } - } - - if !namedParamsOnce && len(params) > 0 { - namedParamsOnce = true - tableParams, _ = newTableParams(params[len(params)-1]) - } - - if tableParams != nil { - dst, ok = tableParams.AppendParam(f, dst, id) - if ok { - continue - } - } - - if f.model != nil { - dst, ok = f.model.AppendParam(f, dst, id) - if ok { - continue - } - } - - restore_param: - dst = append(dst, '?') - dst = append(dst, id...) - continue - } - - if paramsIndex >= len(params) { - dst = append(dst, '?') - continue - } - - param := params[paramsIndex] - paramsIndex++ - - dst = f.appendParam(dst, param) - } - - return dst -} - -func (f *Formatter) appendParam(b []byte, param interface{}) []byte { - switch param := param.(type) { - case QueryAppender: - bb, err := param.AppendQuery(f, b) - if err != nil { - return types.AppendError(b, err) - } - return bb - default: - return types.Append(b, param, 1) - } -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/hook.go b/vendor/github.com/go-pg/pg/v10/orm/hook.go deleted file mode 100644 index 78bd10310..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/hook.go +++ /dev/null @@ -1,248 +0,0 @@ -package orm - -import ( - "context" - "reflect" -) - -type hookStubs struct{} - -var ( - _ AfterScanHook = (*hookStubs)(nil) - _ AfterSelectHook = (*hookStubs)(nil) - _ BeforeInsertHook = (*hookStubs)(nil) - _ AfterInsertHook = (*hookStubs)(nil) - _ BeforeUpdateHook = (*hookStubs)(nil) - _ AfterUpdateHook = (*hookStubs)(nil) - _ BeforeDeleteHook = (*hookStubs)(nil) - _ AfterDeleteHook = (*hookStubs)(nil) -) - -func (hookStubs) AfterScan(ctx context.Context) error { - return nil -} - -func (hookStubs) AfterSelect(ctx context.Context) error { - return nil -} - -func (hookStubs) BeforeInsert(ctx context.Context) (context.Context, error) { - return ctx, nil -} - -func (hookStubs) AfterInsert(ctx context.Context) error { - return nil -} - -func (hookStubs) BeforeUpdate(ctx context.Context) (context.Context, error) { - return ctx, nil -} - -func (hookStubs) AfterUpdate(ctx context.Context) error { - return nil -} - -func (hookStubs) BeforeDelete(ctx context.Context) (context.Context, error) { - return ctx, nil -} - -func (hookStubs) AfterDelete(ctx context.Context) error { - return nil -} - -func callHookSlice( - ctx context.Context, - slice reflect.Value, - ptr bool, - hook func(context.Context, reflect.Value) (context.Context, error), -) (context.Context, error) { - var firstErr error - sliceLen := slice.Len() - for i := 0; i < sliceLen; i++ { - v := slice.Index(i) - if !ptr { - v = v.Addr() - } - - var err error - ctx, err = hook(ctx, v) - if err != nil && firstErr == nil { - firstErr = err - } - } - return ctx, firstErr -} - -func callHookSlice2( - ctx context.Context, - slice reflect.Value, - ptr bool, - hook func(context.Context, reflect.Value) error, -) error { - var firstErr error - if slice.IsValid() { - sliceLen := slice.Len() - for i := 0; i < sliceLen; i++ { - v := slice.Index(i) - if !ptr { - v = v.Addr() - } - - err := hook(ctx, v) - if err != nil && firstErr == nil { - firstErr = err - } - } - } - return firstErr -} - -//------------------------------------------------------------------------------ - -type BeforeScanHook interface { - BeforeScan(context.Context) error -} - -var beforeScanHookType = reflect.TypeOf((*BeforeScanHook)(nil)).Elem() - -func callBeforeScanHook(ctx context.Context, v reflect.Value) error { - return v.Interface().(BeforeScanHook).BeforeScan(ctx) -} - -//------------------------------------------------------------------------------ - -type AfterScanHook interface { - AfterScan(context.Context) error -} - -var afterScanHookType = reflect.TypeOf((*AfterScanHook)(nil)).Elem() - -func callAfterScanHook(ctx context.Context, v reflect.Value) error { - return v.Interface().(AfterScanHook).AfterScan(ctx) -} - -//------------------------------------------------------------------------------ - -type AfterSelectHook interface { - AfterSelect(context.Context) error -} - -var afterSelectHookType = reflect.TypeOf((*AfterSelectHook)(nil)).Elem() - -func callAfterSelectHook(ctx context.Context, v reflect.Value) error { - return v.Interface().(AfterSelectHook).AfterSelect(ctx) -} - -func callAfterSelectHookSlice( - ctx context.Context, slice reflect.Value, ptr bool, -) error { - return callHookSlice2(ctx, slice, ptr, callAfterSelectHook) -} - -//------------------------------------------------------------------------------ - -type BeforeInsertHook interface { - BeforeInsert(context.Context) (context.Context, error) -} - -var beforeInsertHookType = reflect.TypeOf((*BeforeInsertHook)(nil)).Elem() - -func callBeforeInsertHook(ctx context.Context, v reflect.Value) (context.Context, error) { - return v.Interface().(BeforeInsertHook).BeforeInsert(ctx) -} - -func callBeforeInsertHookSlice( - ctx context.Context, slice reflect.Value, ptr bool, -) (context.Context, error) { - return callHookSlice(ctx, slice, ptr, callBeforeInsertHook) -} - -//------------------------------------------------------------------------------ - -type AfterInsertHook interface { - AfterInsert(context.Context) error -} - -var afterInsertHookType = reflect.TypeOf((*AfterInsertHook)(nil)).Elem() - -func callAfterInsertHook(ctx context.Context, v reflect.Value) error { - return v.Interface().(AfterInsertHook).AfterInsert(ctx) -} - -func callAfterInsertHookSlice( - ctx context.Context, slice reflect.Value, ptr bool, -) error { - return callHookSlice2(ctx, slice, ptr, callAfterInsertHook) -} - -//------------------------------------------------------------------------------ - -type BeforeUpdateHook interface { - BeforeUpdate(context.Context) (context.Context, error) -} - -var beforeUpdateHookType = reflect.TypeOf((*BeforeUpdateHook)(nil)).Elem() - -func callBeforeUpdateHook(ctx context.Context, v reflect.Value) (context.Context, error) { - return v.Interface().(BeforeUpdateHook).BeforeUpdate(ctx) -} - -func callBeforeUpdateHookSlice( - ctx context.Context, slice reflect.Value, ptr bool, -) (context.Context, error) { - return callHookSlice(ctx, slice, ptr, callBeforeUpdateHook) -} - -//------------------------------------------------------------------------------ - -type AfterUpdateHook interface { - AfterUpdate(context.Context) error -} - -var afterUpdateHookType = reflect.TypeOf((*AfterUpdateHook)(nil)).Elem() - -func callAfterUpdateHook(ctx context.Context, v reflect.Value) error { - return v.Interface().(AfterUpdateHook).AfterUpdate(ctx) -} - -func callAfterUpdateHookSlice( - ctx context.Context, slice reflect.Value, ptr bool, -) error { - return callHookSlice2(ctx, slice, ptr, callAfterUpdateHook) -} - -//------------------------------------------------------------------------------ - -type BeforeDeleteHook interface { - BeforeDelete(context.Context) (context.Context, error) -} - -var beforeDeleteHookType = reflect.TypeOf((*BeforeDeleteHook)(nil)).Elem() - -func callBeforeDeleteHook(ctx context.Context, v reflect.Value) (context.Context, error) { - return v.Interface().(BeforeDeleteHook).BeforeDelete(ctx) -} - -func callBeforeDeleteHookSlice( - ctx context.Context, slice reflect.Value, ptr bool, -) (context.Context, error) { - return callHookSlice(ctx, slice, ptr, callBeforeDeleteHook) -} - -//------------------------------------------------------------------------------ - -type AfterDeleteHook interface { - AfterDelete(context.Context) error -} - -var afterDeleteHookType = reflect.TypeOf((*AfterDeleteHook)(nil)).Elem() - -func callAfterDeleteHook(ctx context.Context, v reflect.Value) error { - return v.Interface().(AfterDeleteHook).AfterDelete(ctx) -} - -func callAfterDeleteHookSlice( - ctx context.Context, slice reflect.Value, ptr bool, -) error { - return callHookSlice2(ctx, slice, ptr, callAfterDeleteHook) -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/insert.go b/vendor/github.com/go-pg/pg/v10/orm/insert.go deleted file mode 100644 index a7a543576..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/insert.go +++ /dev/null @@ -1,345 +0,0 @@ -package orm - -import ( - "fmt" - "reflect" - "sort" - - "github.com/go-pg/pg/v10/types" -) - -type InsertQuery struct { - q *Query - returningFields []*Field - placeholder bool -} - -var _ QueryCommand = (*InsertQuery)(nil) - -func NewInsertQuery(q *Query) *InsertQuery { - return &InsertQuery{ - q: q, - } -} - -func (q *InsertQuery) String() string { - b, err := q.AppendQuery(defaultFmter, nil) - if err != nil { - panic(err) - } - return string(b) -} - -func (q *InsertQuery) Operation() QueryOp { - return InsertOp -} - -func (q *InsertQuery) Clone() QueryCommand { - return &InsertQuery{ - q: q.q.Clone(), - placeholder: q.placeholder, - } -} - -func (q *InsertQuery) Query() *Query { - return q.q -} - -var _ TemplateAppender = (*InsertQuery)(nil) - -func (q *InsertQuery) AppendTemplate(b []byte) ([]byte, error) { - cp := q.Clone().(*InsertQuery) - cp.placeholder = true - return cp.AppendQuery(dummyFormatter{}, b) -} - -var _ QueryAppender = (*InsertQuery)(nil) - -func (q *InsertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { - if q.q.stickyErr != nil { - return nil, q.q.stickyErr - } - - if len(q.q.with) > 0 { - b, err = q.q.appendWith(fmter, b) - if err != nil { - return nil, err - } - } - - b = append(b, "INSERT INTO "...) - if q.q.onConflict != nil { - b, err = q.q.appendFirstTableWithAlias(fmter, b) - } else { - b, err = q.q.appendFirstTable(fmter, b) - } - if err != nil { - return nil, err - } - - b, err = q.appendColumnsValues(fmter, b) - if err != nil { - return nil, err - } - - if q.q.onConflict != nil { - b = append(b, " ON CONFLICT "...) - b, err = q.q.onConflict.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - - if q.q.onConflictDoUpdate() { - if len(q.q.set) > 0 { - b, err = q.q.appendSet(fmter, b) - if err != nil { - return nil, err - } - } else { - fields, err := q.q.getDataFields() - if err != nil { - return nil, err - } - - if len(fields) == 0 { - fields = q.q.tableModel.Table().DataFields - } - - b = q.appendSetExcluded(b, fields) - } - - if len(q.q.updWhere) > 0 { - b = append(b, " WHERE "...) - b, err = q.q.appendUpdWhere(fmter, b) - if err != nil { - return nil, err - } - } - } - } - - if len(q.q.returning) > 0 { - b, err = q.q.appendReturning(fmter, b) - if err != nil { - return nil, err - } - } else if len(q.returningFields) > 0 { - b = appendReturningFields(b, q.returningFields) - } - - return b, q.q.stickyErr -} - -func (q *InsertQuery) appendColumnsValues(fmter QueryFormatter, b []byte) (_ []byte, err error) { - if q.q.hasMultiTables() { - if q.q.columns != nil { - b = append(b, " ("...) - b, err = q.q.appendColumns(fmter, b) - if err != nil { - return nil, err - } - b = append(b, ")"...) - } - - b = append(b, " SELECT * FROM "...) - b, err = q.q.appendOtherTables(fmter, b) - if err != nil { - return nil, err - } - - return b, nil - } - - if m, ok := q.q.model.(*mapModel); ok { - return q.appendMapColumnsValues(b, m.m), nil - } - - if !q.q.hasTableModel() { - return nil, errModelNil - } - - fields, err := q.q.getFields() - if err != nil { - return nil, err - } - - if len(fields) == 0 { - fields = q.q.tableModel.Table().Fields - } - value := q.q.tableModel.Value() - - b = append(b, " ("...) - b = q.appendColumns(b, fields) - b = append(b, ") VALUES ("...) - if m, ok := q.q.tableModel.(*sliceTableModel); ok { - if m.sliceLen == 0 { - err = fmt.Errorf("pg: can't bulk-insert empty slice %s", value.Type()) - return nil, err - } - b, err = q.appendSliceValues(fmter, b, fields, value) - if err != nil { - return nil, err - } - } else { - b, err = q.appendValues(fmter, b, fields, value) - if err != nil { - return nil, err - } - } - b = append(b, ")"...) - - return b, nil -} - -func (q *InsertQuery) appendMapColumnsValues(b []byte, m map[string]interface{}) []byte { - keys := make([]string, 0, len(m)) - - for k := range m { - keys = append(keys, k) - } - sort.Strings(keys) - - b = append(b, " ("...) - - for i, k := range keys { - if i > 0 { - b = append(b, ", "...) - } - b = types.AppendIdent(b, k, 1) - } - - b = append(b, ") VALUES ("...) - - for i, k := range keys { - if i > 0 { - b = append(b, ", "...) - } - if q.placeholder { - b = append(b, '?') - } else { - b = types.Append(b, m[k], 1) - } - } - - b = append(b, ")"...) - - return b -} - -func (q *InsertQuery) appendValues( - fmter QueryFormatter, b []byte, fields []*Field, strct reflect.Value, -) (_ []byte, err error) { - for i, f := range fields { - if i > 0 { - b = append(b, ", "...) - } - - app, ok := q.q.modelValues[f.SQLName] - if ok { - b, err = app.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - q.addReturningField(f) - continue - } - - switch { - case q.placeholder: - b = append(b, '?') - case (f.Default != "" || f.NullZero()) && f.HasZeroValue(strct): - b = append(b, "DEFAULT"...) - q.addReturningField(f) - default: - b = f.AppendValue(b, strct, 1) - } - } - - for i, v := range q.q.extraValues { - if i > 0 || len(fields) > 0 { - b = append(b, ", "...) - } - - b, err = v.value.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - } - - return b, nil -} - -func (q *InsertQuery) appendSliceValues( - fmter QueryFormatter, b []byte, fields []*Field, slice reflect.Value, -) (_ []byte, err error) { - if q.placeholder { - return q.appendValues(fmter, b, fields, reflect.Value{}) - } - - sliceLen := slice.Len() - for i := 0; i < sliceLen; i++ { - if i > 0 { - b = append(b, "), ("...) - } - el := indirect(slice.Index(i)) - b, err = q.appendValues(fmter, b, fields, el) - if err != nil { - return nil, err - } - } - - for i, v := range q.q.extraValues { - if i > 0 || len(fields) > 0 { - b = append(b, ", "...) - } - - b, err = v.value.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - } - - return b, nil -} - -func (q *InsertQuery) addReturningField(field *Field) { - if len(q.q.returning) > 0 { - return - } - for _, f := range q.returningFields { - if f == field { - return - } - } - q.returningFields = append(q.returningFields, field) -} - -func (q *InsertQuery) appendSetExcluded(b []byte, fields []*Field) []byte { - b = append(b, " SET "...) - for i, f := range fields { - if i > 0 { - b = append(b, ", "...) - } - b = append(b, f.Column...) - b = append(b, " = EXCLUDED."...) - b = append(b, f.Column...) - } - return b -} - -func (q *InsertQuery) appendColumns(b []byte, fields []*Field) []byte { - b = appendColumns(b, "", fields) - for i, v := range q.q.extraValues { - if i > 0 || len(fields) > 0 { - b = append(b, ", "...) - } - b = types.AppendIdent(b, v.column, 1) - } - return b -} - -func appendReturningFields(b []byte, fields []*Field) []byte { - b = append(b, " RETURNING "...) - b = appendColumns(b, "", fields) - return b -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/join.go b/vendor/github.com/go-pg/pg/v10/orm/join.go deleted file mode 100644 index 2b64ba1b8..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/join.go +++ /dev/null @@ -1,351 +0,0 @@ -package orm - -import ( - "reflect" - - "github.com/go-pg/pg/v10/internal" - "github.com/go-pg/pg/v10/types" -) - -type join struct { - Parent *join - BaseModel TableModel - JoinModel TableModel - Rel *Relation - - ApplyQuery func(*Query) (*Query, error) - Columns []string - on []*condAppender -} - -func (j *join) AppendOn(app *condAppender) { - j.on = append(j.on, app) -} - -func (j *join) Select(fmter QueryFormatter, q *Query) error { - switch j.Rel.Type { - case HasManyRelation: - return j.selectMany(fmter, q) - case Many2ManyRelation: - return j.selectM2M(fmter, q) - } - panic("not reached") -} - -func (j *join) selectMany(_ QueryFormatter, q *Query) error { - q, err := j.manyQuery(q) - if err != nil { - return err - } - if q == nil { - return nil - } - return q.Select() -} - -func (j *join) manyQuery(q *Query) (*Query, error) { - manyModel := newManyModel(j) - if manyModel == nil { - return nil, nil - } - - q = q.Model(manyModel) - if j.ApplyQuery != nil { - var err error - q, err = j.ApplyQuery(q) - if err != nil { - return nil, err - } - } - - if len(q.columns) == 0 { - q.columns = append(q.columns, &hasManyColumnsAppender{j}) - } - - baseTable := j.BaseModel.Table() - var where []byte - if len(j.Rel.JoinFKs) > 1 { - where = append(where, '(') - } - where = appendColumns(where, j.JoinModel.Table().Alias, j.Rel.JoinFKs) - if len(j.Rel.JoinFKs) > 1 { - where = append(where, ')') - } - where = append(where, " IN ("...) - where = appendChildValues( - where, j.JoinModel.Root(), j.JoinModel.ParentIndex(), j.Rel.BaseFKs) - where = append(where, ")"...) - q = q.Where(internal.BytesToString(where)) - - if j.Rel.Polymorphic != nil { - q = q.Where(`? IN (?, ?)`, - j.Rel.Polymorphic.Column, - baseTable.ModelName, baseTable.TypeName) - } - - return q, nil -} - -func (j *join) selectM2M(fmter QueryFormatter, q *Query) error { - q, err := j.m2mQuery(fmter, q) - if err != nil { - return err - } - if q == nil { - return nil - } - return q.Select() -} - -func (j *join) m2mQuery(fmter QueryFormatter, q *Query) (*Query, error) { - m2mModel := newM2MModel(j) - if m2mModel == nil { - return nil, nil - } - - q = q.Model(m2mModel) - if j.ApplyQuery != nil { - var err error - q, err = j.ApplyQuery(q) - if err != nil { - return nil, err - } - } - - if len(q.columns) == 0 { - q.columns = append(q.columns, &hasManyColumnsAppender{j}) - } - - index := j.JoinModel.ParentIndex() - baseTable := j.BaseModel.Table() - - //nolint - var join []byte - join = append(join, "JOIN "...) - join = fmter.FormatQuery(join, string(j.Rel.M2MTableName)) - join = append(join, " AS "...) - join = append(join, j.Rel.M2MTableAlias...) - join = append(join, " ON ("...) - for i, col := range j.Rel.M2MBaseFKs { - if i > 0 { - join = append(join, ", "...) - } - join = append(join, j.Rel.M2MTableAlias...) - join = append(join, '.') - join = types.AppendIdent(join, col, 1) - } - join = append(join, ") IN ("...) - join = appendChildValues(join, j.BaseModel.Root(), index, baseTable.PKs) - join = append(join, ")"...) - q = q.Join(internal.BytesToString(join)) - - joinTable := j.JoinModel.Table() - for i, col := range j.Rel.M2MJoinFKs { - pk := joinTable.PKs[i] - q = q.Where("?.? = ?.?", - joinTable.Alias, pk.Column, - j.Rel.M2MTableAlias, types.Ident(col)) - } - - return q, nil -} - -func (j *join) hasParent() bool { - if j.Parent != nil { - switch j.Parent.Rel.Type { - case HasOneRelation, BelongsToRelation: - return true - } - } - return false -} - -func (j *join) appendAlias(b []byte) []byte { - b = append(b, '"') - b = appendAlias(b, j) - b = append(b, '"') - return b -} - -func (j *join) appendAliasColumn(b []byte, column string) []byte { - b = append(b, '"') - b = appendAlias(b, j) - b = append(b, "__"...) - b = append(b, column...) - b = append(b, '"') - return b -} - -func (j *join) appendBaseAlias(b []byte) []byte { - if j.hasParent() { - b = append(b, '"') - b = appendAlias(b, j.Parent) - b = append(b, '"') - return b - } - return append(b, j.BaseModel.Table().Alias...) -} - -func (j *join) appendSoftDelete(b []byte, flags queryFlag) []byte { - b = append(b, '.') - b = append(b, j.JoinModel.Table().SoftDeleteField.Column...) - if hasFlag(flags, deletedFlag) { - b = append(b, " IS NOT NULL"...) - } else { - b = append(b, " IS NULL"...) - } - return b -} - -func appendAlias(b []byte, j *join) []byte { - if j.hasParent() { - b = appendAlias(b, j.Parent) - b = append(b, "__"...) - } - b = append(b, j.Rel.Field.SQLName...) - return b -} - -func (j *join) appendHasOneColumns(b []byte) []byte { - if j.Columns == nil { - for i, f := range j.JoinModel.Table().Fields { - if i > 0 { - b = append(b, ", "...) - } - b = j.appendAlias(b) - b = append(b, '.') - b = append(b, f.Column...) - b = append(b, " AS "...) - b = j.appendAliasColumn(b, f.SQLName) - } - return b - } - - for i, column := range j.Columns { - if i > 0 { - b = append(b, ", "...) - } - b = j.appendAlias(b) - b = append(b, '.') - b = types.AppendIdent(b, column, 1) - b = append(b, " AS "...) - b = j.appendAliasColumn(b, column) - } - - return b -} - -func (j *join) appendHasOneJoin(fmter QueryFormatter, b []byte, q *Query) (_ []byte, err error) { - isSoftDelete := j.JoinModel.Table().SoftDeleteField != nil && !q.hasFlag(allWithDeletedFlag) - - b = append(b, "LEFT JOIN "...) - b = fmter.FormatQuery(b, string(j.JoinModel.Table().SQLNameForSelects)) - b = append(b, " AS "...) - b = j.appendAlias(b) - - b = append(b, " ON "...) - - if isSoftDelete { - b = append(b, '(') - } - - if len(j.Rel.BaseFKs) > 1 { - b = append(b, '(') - } - for i, baseFK := range j.Rel.BaseFKs { - if i > 0 { - b = append(b, " AND "...) - } - b = j.appendAlias(b) - b = append(b, '.') - b = append(b, j.Rel.JoinFKs[i].Column...) - b = append(b, " = "...) - b = j.appendBaseAlias(b) - b = append(b, '.') - b = append(b, baseFK.Column...) - } - if len(j.Rel.BaseFKs) > 1 { - b = append(b, ')') - } - - for _, on := range j.on { - b = on.AppendSep(b) - b, err = on.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - } - - if isSoftDelete { - b = append(b, ')') - } - - if isSoftDelete { - b = append(b, " AND "...) - b = j.appendAlias(b) - b = j.appendSoftDelete(b, q.flags) - } - - return b, nil -} - -type hasManyColumnsAppender struct { - *join -} - -var _ QueryAppender = (*hasManyColumnsAppender)(nil) - -func (q *hasManyColumnsAppender) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { - if q.Rel.M2MTableAlias != "" { - b = append(b, q.Rel.M2MTableAlias...) - b = append(b, ".*, "...) - } - - joinTable := q.JoinModel.Table() - - if q.Columns != nil { - for i, column := range q.Columns { - if i > 0 { - b = append(b, ", "...) - } - b = append(b, joinTable.Alias...) - b = append(b, '.') - b = types.AppendIdent(b, column, 1) - } - return b, nil - } - - b = appendColumns(b, joinTable.Alias, joinTable.Fields) - return b, nil -} - -func appendChildValues(b []byte, v reflect.Value, index []int, fields []*Field) []byte { - seen := make(map[string]struct{}) - walk(v, index, func(v reflect.Value) { - start := len(b) - - if len(fields) > 1 { - b = append(b, '(') - } - for i, f := range fields { - if i > 0 { - b = append(b, ", "...) - } - b = f.AppendValue(b, v, 1) - } - if len(fields) > 1 { - b = append(b, ')') - } - b = append(b, ", "...) - - if _, ok := seen[string(b[start:])]; ok { - b = b[:start] - } else { - seen[string(b[start:])] = struct{}{} - } - }) - if len(seen) > 0 { - b = b[:len(b)-2] // trim ", " - } - return b -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model.go b/vendor/github.com/go-pg/pg/v10/orm/model.go deleted file mode 100644 index 333a90dd7..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/model.go +++ /dev/null @@ -1,150 +0,0 @@ -package orm - -import ( - "database/sql" - "errors" - "fmt" - "reflect" - - "github.com/go-pg/pg/v10/types" -) - -var errModelNil = errors.New("pg: Model(nil)") - -type useQueryOne interface { - useQueryOne() bool -} - -type HooklessModel interface { - // Init is responsible to initialize/reset model state. - // It is called only once no matter how many rows were returned. - Init() error - - // NextColumnScanner returns a ColumnScanner that is used to scan columns - // from the current row. It is called once for every row. - NextColumnScanner() ColumnScanner - - // AddColumnScanner adds the ColumnScanner to the model. - AddColumnScanner(ColumnScanner) error -} - -type Model interface { - HooklessModel - - AfterScanHook - AfterSelectHook - - BeforeInsertHook - AfterInsertHook - - BeforeUpdateHook - AfterUpdateHook - - BeforeDeleteHook - AfterDeleteHook -} - -func NewModel(value interface{}) (Model, error) { - return newModel(value, false) -} - -func newScanModel(values []interface{}) (Model, error) { - if len(values) > 1 { - return Scan(values...), nil - } - return newModel(values[0], true) -} - -func newModel(value interface{}, scan bool) (Model, error) { - switch value := value.(type) { - case Model: - return value, nil - case HooklessModel: - return newModelWithHookStubs(value), nil - case types.ValueScanner, sql.Scanner: - if !scan { - return nil, fmt.Errorf("pg: Model(unsupported %T)", value) - } - return Scan(value), nil - } - - v := reflect.ValueOf(value) - if !v.IsValid() { - return nil, errModelNil - } - if v.Kind() != reflect.Ptr { - return nil, fmt.Errorf("pg: Model(non-pointer %T)", value) - } - - if v.IsNil() { - typ := v.Type().Elem() - if typ.Kind() == reflect.Struct { - return newStructTableModel(GetTable(typ)), nil - } - return nil, errModelNil - } - - v = v.Elem() - - if v.Kind() == reflect.Interface { - if !v.IsNil() { - v = v.Elem() - if v.Kind() != reflect.Ptr { - return nil, fmt.Errorf("pg: Model(non-pointer %s)", v.Type().String()) - } - } - } - - switch v.Kind() { - case reflect.Struct: - if v.Type() != timeType { - return newStructTableModelValue(v), nil - } - case reflect.Slice: - elemType := sliceElemType(v) - switch elemType.Kind() { - case reflect.Struct: - if elemType != timeType { - return newSliceTableModel(v, elemType), nil - } - case reflect.Map: - if err := validMap(elemType); err != nil { - return nil, err - } - slicePtr := v.Addr().Interface().(*[]map[string]interface{}) - return newMapSliceModel(slicePtr), nil - } - return newSliceModel(v, elemType), nil - case reflect.Map: - typ := v.Type() - if err := validMap(typ); err != nil { - return nil, err - } - mapPtr := v.Addr().Interface().(*map[string]interface{}) - return newMapModel(mapPtr), nil - } - - if !scan { - return nil, fmt.Errorf("pg: Model(unsupported %T)", value) - } - return Scan(value), nil -} - -type modelWithHookStubs struct { - hookStubs - HooklessModel -} - -func newModelWithHookStubs(m HooklessModel) Model { - return modelWithHookStubs{ - HooklessModel: m, - } -} - -func validMap(typ reflect.Type) error { - if typ.Key().Kind() != reflect.String || typ.Elem().Kind() != reflect.Interface { - return fmt.Errorf("pg: Model(unsupported %s, expected *map[string]interface{})", - typ.String()) - } - return nil -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_discard.go b/vendor/github.com/go-pg/pg/v10/orm/model_discard.go deleted file mode 100644 index 92e5c566c..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/model_discard.go +++ /dev/null @@ -1,27 +0,0 @@ -package orm - -import ( - "github.com/go-pg/pg/v10/types" -) - -type Discard struct { - hookStubs -} - -var _ Model = (*Discard)(nil) - -func (Discard) Init() error { - return nil -} - -func (m Discard) NextColumnScanner() ColumnScanner { - return m -} - -func (m Discard) AddColumnScanner(ColumnScanner) error { - return nil -} - -func (m Discard) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { - return nil -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_func.go b/vendor/github.com/go-pg/pg/v10/orm/model_func.go deleted file mode 100644 index 8427bdea2..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/model_func.go +++ /dev/null @@ -1,89 +0,0 @@ -package orm - -import ( - "fmt" - "reflect" -) - -var errorType = reflect.TypeOf((*error)(nil)).Elem() - -type funcModel struct { - Model - fnv reflect.Value - fnIn []reflect.Value -} - -var _ Model = (*funcModel)(nil) - -func newFuncModel(fn interface{}) *funcModel { - m := &funcModel{ - fnv: reflect.ValueOf(fn), - } - - fnt := m.fnv.Type() - if fnt.Kind() != reflect.Func { - panic(fmt.Errorf("ForEach expects a %s, got a %s", - reflect.Func, fnt.Kind())) - } - - if fnt.NumIn() < 1 { - panic(fmt.Errorf("ForEach expects at least 1 arg, got %d", fnt.NumIn())) - } - - if fnt.NumOut() != 1 { - panic(fmt.Errorf("ForEach must return 1 error value, got %d", fnt.NumOut())) - } - if fnt.Out(0) != errorType { - panic(fmt.Errorf("ForEach must return an error, got %T", fnt.Out(0))) - } - - if fnt.NumIn() > 1 { - initFuncModelScan(m, fnt) - return m - } - - t0 := fnt.In(0) - var v0 reflect.Value - if t0.Kind() == reflect.Ptr { - t0 = t0.Elem() - v0 = reflect.New(t0) - } else { - v0 = reflect.New(t0).Elem() - } - - m.fnIn = []reflect.Value{v0} - - model, ok := v0.Interface().(Model) - if ok { - m.Model = model - return m - } - - if v0.Kind() == reflect.Ptr { - v0 = v0.Elem() - } - if v0.Kind() != reflect.Struct { - panic(fmt.Errorf("ForEach accepts a %s, got %s", - reflect.Struct, v0.Kind())) - } - m.Model = newStructTableModelValue(v0) - - return m -} - -func initFuncModelScan(m *funcModel, fnt reflect.Type) { - m.fnIn = make([]reflect.Value, fnt.NumIn()) - for i := 0; i < fnt.NumIn(); i++ { - m.fnIn[i] = reflect.New(fnt.In(i)).Elem() - } - m.Model = scanReflectValues(m.fnIn) -} - -func (m *funcModel) AddColumnScanner(_ ColumnScanner) error { - out := m.fnv.Call(m.fnIn) - errv := out[0] - if !errv.IsNil() { - return errv.Interface().(error) - } - return nil -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_map.go b/vendor/github.com/go-pg/pg/v10/orm/model_map.go deleted file mode 100644 index 24533d43c..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/model_map.go +++ /dev/null @@ -1,53 +0,0 @@ -package orm - -import ( - "github.com/go-pg/pg/v10/types" -) - -type mapModel struct { - hookStubs - ptr *map[string]interface{} - m map[string]interface{} -} - -var _ Model = (*mapModel)(nil) - -func newMapModel(ptr *map[string]interface{}) *mapModel { - model := &mapModel{ - ptr: ptr, - } - if ptr != nil { - model.m = *ptr - } - return model -} - -func (m *mapModel) Init() error { - return nil -} - -func (m *mapModel) NextColumnScanner() ColumnScanner { - if m.m == nil { - m.m = make(map[string]interface{}) - *m.ptr = m.m - } - return m -} - -func (m mapModel) AddColumnScanner(ColumnScanner) error { - return nil -} - -func (m *mapModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { - val, err := types.ReadColumnValue(col, rd, n) - if err != nil { - return err - } - - m.m[col.Name] = val - return nil -} - -func (mapModel) useQueryOne() bool { - return true -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_map_slice.go b/vendor/github.com/go-pg/pg/v10/orm/model_map_slice.go deleted file mode 100644 index ea14c9b6b..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/model_map_slice.go +++ /dev/null @@ -1,45 +0,0 @@ -package orm - -type mapSliceModel struct { - mapModel - slice *[]map[string]interface{} -} - -var _ Model = (*mapSliceModel)(nil) - -func newMapSliceModel(ptr *[]map[string]interface{}) *mapSliceModel { - return &mapSliceModel{ - slice: ptr, - } -} - -func (m *mapSliceModel) Init() error { - slice := *m.slice - if len(slice) > 0 { - *m.slice = slice[:0] - } - return nil -} - -func (m *mapSliceModel) NextColumnScanner() ColumnScanner { - slice := *m.slice - if len(slice) == cap(slice) { - m.mapModel.m = make(map[string]interface{}) - *m.slice = append(slice, m.mapModel.m) //nolint:gocritic - return m - } - - slice = slice[:len(slice)+1] - el := slice[len(slice)-1] - if el != nil { - m.mapModel.m = el - } else { - el = make(map[string]interface{}) - slice[len(slice)-1] = el - m.mapModel.m = el - } - *m.slice = slice - return m -} - -func (mapSliceModel) useQueryOne() {} //nolint:unused diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_scan.go b/vendor/github.com/go-pg/pg/v10/orm/model_scan.go deleted file mode 100644 index 08f66beba..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/model_scan.go +++ /dev/null @@ -1,69 +0,0 @@ -package orm - -import ( - "fmt" - "reflect" - - "github.com/go-pg/pg/v10/types" -) - -type scanValuesModel struct { - Discard - values []interface{} -} - -var _ Model = scanValuesModel{} - -//nolint -func Scan(values ...interface{}) scanValuesModel { - return scanValuesModel{ - values: values, - } -} - -func (scanValuesModel) useQueryOne() bool { - return true -} - -func (m scanValuesModel) NextColumnScanner() ColumnScanner { - return m -} - -func (m scanValuesModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { - if int(col.Index) >= len(m.values) { - return fmt.Errorf("pg: no Scan var for column index=%d name=%q", - col.Index, col.Name) - } - return types.Scan(m.values[col.Index], rd, n) -} - -//------------------------------------------------------------------------------ - -type scanReflectValuesModel struct { - Discard - values []reflect.Value -} - -var _ Model = scanReflectValuesModel{} - -func scanReflectValues(values []reflect.Value) scanReflectValuesModel { - return scanReflectValuesModel{ - values: values, - } -} - -func (scanReflectValuesModel) useQueryOne() bool { - return true -} - -func (m scanReflectValuesModel) NextColumnScanner() ColumnScanner { - return m -} - -func (m scanReflectValuesModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { - if int(col.Index) >= len(m.values) { - return fmt.Errorf("pg: no Scan var for column index=%d name=%q", - col.Index, col.Name) - } - return types.ScanValue(m.values[col.Index], rd, n) -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_slice.go b/vendor/github.com/go-pg/pg/v10/orm/model_slice.go deleted file mode 100644 index 1e163629e..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/model_slice.go +++ /dev/null @@ -1,43 +0,0 @@ -package orm - -import ( - "reflect" - - "github.com/go-pg/pg/v10/internal" - "github.com/go-pg/pg/v10/types" -) - -type sliceModel struct { - Discard - slice reflect.Value - nextElem func() reflect.Value - scan func(reflect.Value, types.Reader, int) error -} - -var _ Model = (*sliceModel)(nil) - -func newSliceModel(slice reflect.Value, elemType reflect.Type) *sliceModel { - return &sliceModel{ - slice: slice, - scan: types.Scanner(elemType), - } -} - -func (m *sliceModel) Init() error { - if m.slice.IsValid() && m.slice.Len() > 0 { - m.slice.Set(m.slice.Slice(0, 0)) - } - return nil -} - -func (m *sliceModel) NextColumnScanner() ColumnScanner { - return m -} - -func (m *sliceModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { - if m.nextElem == nil { - m.nextElem = internal.MakeSliceNextElemFunc(m.slice) - } - v := m.nextElem() - return m.scan(v, rd, n) -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_table.go b/vendor/github.com/go-pg/pg/v10/orm/model_table.go deleted file mode 100644 index afdc15ccc..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/model_table.go +++ /dev/null @@ -1,65 +0,0 @@ -package orm - -import ( - "fmt" - "reflect" - - "github.com/go-pg/pg/v10/types" -) - -type TableModel interface { - Model - - IsNil() bool - Table() *Table - Relation() *Relation - AppendParam(QueryFormatter, []byte, string) ([]byte, bool) - - Join(string, func(*Query) (*Query, error)) *join - GetJoin(string) *join - GetJoins() []join - AddJoin(join) *join - - Root() reflect.Value - Index() []int - ParentIndex() []int - Mount(reflect.Value) - Kind() reflect.Kind - Value() reflect.Value - - setSoftDeleteField() error - scanColumn(types.ColumnInfo, types.Reader, int) (bool, error) -} - -func newTableModelIndex(typ reflect.Type, root reflect.Value, index []int, rel *Relation) (TableModel, error) { - typ = typeByIndex(typ, index) - - if typ.Kind() == reflect.Struct { - return &structTableModel{ - table: GetTable(typ), - rel: rel, - - root: root, - index: index, - }, nil - } - - if typ.Kind() == reflect.Slice { - structType := indirectType(typ.Elem()) - if structType.Kind() == reflect.Struct { - m := sliceTableModel{ - structTableModel: structTableModel{ - table: GetTable(structType), - rel: rel, - - root: root, - index: index, - }, - } - m.init(typ) - return &m, nil - } - } - - return nil, fmt.Errorf("pg: NewModel(%s)", typ) -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_table_m2m.go b/vendor/github.com/go-pg/pg/v10/orm/model_table_m2m.go deleted file mode 100644 index 83ac73bde..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/model_table_m2m.go +++ /dev/null @@ -1,111 +0,0 @@ -package orm - -import ( - "fmt" - "reflect" - - "github.com/go-pg/pg/v10/internal/pool" - "github.com/go-pg/pg/v10/types" -) - -type m2mModel struct { - *sliceTableModel - baseTable *Table - rel *Relation - - buf []byte - dstValues map[string][]reflect.Value - columns map[string]string -} - -var _ TableModel = (*m2mModel)(nil) - -func newM2MModel(j *join) *m2mModel { - baseTable := j.BaseModel.Table() - joinModel := j.JoinModel.(*sliceTableModel) - dstValues := dstValues(joinModel, baseTable.PKs) - if len(dstValues) == 0 { - return nil - } - m := &m2mModel{ - sliceTableModel: joinModel, - baseTable: baseTable, - rel: j.Rel, - - dstValues: dstValues, - columns: make(map[string]string), - } - if !m.sliceOfPtr { - m.strct = reflect.New(m.table.Type).Elem() - } - return m -} - -func (m *m2mModel) NextColumnScanner() ColumnScanner { - if m.sliceOfPtr { - m.strct = reflect.New(m.table.Type).Elem() - } else { - m.strct.Set(m.table.zeroStruct) - } - m.structInited = false - return m -} - -func (m *m2mModel) AddColumnScanner(_ ColumnScanner) error { - buf, err := m.modelIDMap(m.buf[:0]) - if err != nil { - return err - } - m.buf = buf - - dstValues, ok := m.dstValues[string(buf)] - if !ok { - return fmt.Errorf( - "pg: relation=%q does not have base %s with id=%q (check join conditions)", - m.rel.Field.GoName, m.baseTable, buf) - } - - for _, v := range dstValues { - if m.sliceOfPtr { - v.Set(reflect.Append(v, m.strct.Addr())) - } else { - v.Set(reflect.Append(v, m.strct)) - } - } - - return nil -} - -func (m *m2mModel) modelIDMap(b []byte) ([]byte, error) { - for i, col := range m.rel.M2MBaseFKs { - if i > 0 { - b = append(b, ',') - } - if s, ok := m.columns[col]; ok { - b = append(b, s...) - } else { - return nil, fmt.Errorf("pg: %s does not have column=%q", - m.sliceTableModel, col) - } - } - return b, nil -} - -func (m *m2mModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { - if n > 0 { - b, err := rd.ReadFullTemp() - if err != nil { - return err - } - - m.columns[col.Name] = string(b) - rd = pool.NewBytesReader(b) - } else { - m.columns[col.Name] = "" - } - - if ok, err := m.sliceTableModel.scanColumn(col, rd, n); ok { - return err - } - return nil -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_table_many.go b/vendor/github.com/go-pg/pg/v10/orm/model_table_many.go deleted file mode 100644 index 561384bba..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/model_table_many.go +++ /dev/null @@ -1,75 +0,0 @@ -package orm - -import ( - "fmt" - "reflect" -) - -type manyModel struct { - *sliceTableModel - baseTable *Table - rel *Relation - - buf []byte - dstValues map[string][]reflect.Value -} - -var _ TableModel = (*manyModel)(nil) - -func newManyModel(j *join) *manyModel { - baseTable := j.BaseModel.Table() - joinModel := j.JoinModel.(*sliceTableModel) - dstValues := dstValues(joinModel, j.Rel.BaseFKs) - if len(dstValues) == 0 { - return nil - } - m := manyModel{ - sliceTableModel: joinModel, - baseTable: baseTable, - rel: j.Rel, - - dstValues: dstValues, - } - if !m.sliceOfPtr { - m.strct = reflect.New(m.table.Type).Elem() - } - return &m -} - -func (m *manyModel) NextColumnScanner() ColumnScanner { - if m.sliceOfPtr { - m.strct = reflect.New(m.table.Type).Elem() - } else { - m.strct.Set(m.table.zeroStruct) - } - m.structInited = false - return m -} - -func (m *manyModel) AddColumnScanner(model ColumnScanner) error { - m.buf = modelID(m.buf[:0], m.strct, m.rel.JoinFKs) - dstValues, ok := m.dstValues[string(m.buf)] - if !ok { - return fmt.Errorf( - "pg: relation=%q does not have base %s with id=%q (check join conditions)", - m.rel.Field.GoName, m.baseTable, m.buf) - } - - for i, v := range dstValues { - if !m.sliceOfPtr { - v.Set(reflect.Append(v, m.strct)) - continue - } - - if i == 0 { - v.Set(reflect.Append(v, m.strct.Addr())) - continue - } - - clone := reflect.New(m.strct.Type()).Elem() - clone.Set(m.strct) - v.Set(reflect.Append(v, clone.Addr())) - } - - return nil -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_table_slice.go b/vendor/github.com/go-pg/pg/v10/orm/model_table_slice.go deleted file mode 100644 index c50be8252..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/model_table_slice.go +++ /dev/null @@ -1,156 +0,0 @@ -package orm - -import ( - "context" - "reflect" - - "github.com/go-pg/pg/v10/internal" -) - -type sliceTableModel struct { - structTableModel - - slice reflect.Value - sliceLen int - sliceOfPtr bool - nextElem func() reflect.Value -} - -var _ TableModel = (*sliceTableModel)(nil) - -func newSliceTableModel(slice reflect.Value, elemType reflect.Type) *sliceTableModel { - m := &sliceTableModel{ - structTableModel: structTableModel{ - table: GetTable(elemType), - root: slice, - }, - slice: slice, - sliceLen: slice.Len(), - nextElem: internal.MakeSliceNextElemFunc(slice), - } - m.init(slice.Type()) - return m -} - -func (m *sliceTableModel) init(sliceType reflect.Type) { - switch sliceType.Elem().Kind() { - case reflect.Ptr, reflect.Interface: - m.sliceOfPtr = true - } -} - -//nolint -func (*sliceTableModel) useQueryOne() {} - -func (m *sliceTableModel) IsNil() bool { - return false -} - -func (m *sliceTableModel) AppendParam(fmter QueryFormatter, b []byte, name string) ([]byte, bool) { - if field, ok := m.table.FieldsMap[name]; ok { - b = append(b, "_data."...) - b = append(b, field.Column...) - return b, true - } - return m.structTableModel.AppendParam(fmter, b, name) -} - -func (m *sliceTableModel) Join(name string, apply func(*Query) (*Query, error)) *join { - return m.join(m.Value(), name, apply) -} - -func (m *sliceTableModel) Bind(bind reflect.Value) { - m.slice = bind.Field(m.index[len(m.index)-1]) -} - -func (m *sliceTableModel) Kind() reflect.Kind { - return reflect.Slice -} - -func (m *sliceTableModel) Value() reflect.Value { - return m.slice -} - -func (m *sliceTableModel) Init() error { - if m.slice.IsValid() && m.slice.Len() > 0 { - m.slice.Set(m.slice.Slice(0, 0)) - } - return nil -} - -func (m *sliceTableModel) NextColumnScanner() ColumnScanner { - m.strct = m.nextElem() - m.structInited = false - return m -} - -func (m *sliceTableModel) AddColumnScanner(_ ColumnScanner) error { - return nil -} - -// Inherit these hooks from structTableModel. -var ( - _ BeforeScanHook = (*sliceTableModel)(nil) - _ AfterScanHook = (*sliceTableModel)(nil) -) - -func (m *sliceTableModel) AfterSelect(ctx context.Context) error { - if m.table.hasFlag(afterSelectHookFlag) { - return callAfterSelectHookSlice(ctx, m.slice, m.sliceOfPtr) - } - return nil -} - -func (m *sliceTableModel) BeforeInsert(ctx context.Context) (context.Context, error) { - if m.table.hasFlag(beforeInsertHookFlag) { - return callBeforeInsertHookSlice(ctx, m.slice, m.sliceOfPtr) - } - return ctx, nil -} - -func (m *sliceTableModel) AfterInsert(ctx context.Context) error { - if m.table.hasFlag(afterInsertHookFlag) { - return callAfterInsertHookSlice(ctx, m.slice, m.sliceOfPtr) - } - return nil -} - -func (m *sliceTableModel) BeforeUpdate(ctx context.Context) (context.Context, error) { - if m.table.hasFlag(beforeUpdateHookFlag) && !m.IsNil() { - return callBeforeUpdateHookSlice(ctx, m.slice, m.sliceOfPtr) - } - return ctx, nil -} - -func (m *sliceTableModel) AfterUpdate(ctx context.Context) error { - if m.table.hasFlag(afterUpdateHookFlag) { - return callAfterUpdateHookSlice(ctx, m.slice, m.sliceOfPtr) - } - return nil -} - -func (m *sliceTableModel) BeforeDelete(ctx context.Context) (context.Context, error) { - if m.table.hasFlag(beforeDeleteHookFlag) && !m.IsNil() { - return callBeforeDeleteHookSlice(ctx, m.slice, m.sliceOfPtr) - } - return ctx, nil -} - -func (m *sliceTableModel) AfterDelete(ctx context.Context) error { - if m.table.hasFlag(afterDeleteHookFlag) && !m.IsNil() { - return callAfterDeleteHookSlice(ctx, m.slice, m.sliceOfPtr) - } - return nil -} - -func (m *sliceTableModel) setSoftDeleteField() error { - sliceLen := m.slice.Len() - for i := 0; i < sliceLen; i++ { - strct := indirect(m.slice.Index(i)) - fv := m.table.SoftDeleteField.Value(strct) - if err := m.table.SetSoftDeleteField(fv); err != nil { - return err - } - } - return nil -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/model_table_struct.go b/vendor/github.com/go-pg/pg/v10/orm/model_table_struct.go deleted file mode 100644 index fce7cc6b7..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/model_table_struct.go +++ /dev/null @@ -1,399 +0,0 @@ -package orm - -import ( - "context" - "fmt" - "reflect" - "strings" - - "github.com/go-pg/pg/v10/types" -) - -type structTableModel struct { - table *Table - rel *Relation - joins []join - - root reflect.Value - index []int - - strct reflect.Value - structInited bool - structInitErr error -} - -var _ TableModel = (*structTableModel)(nil) - -func newStructTableModel(table *Table) *structTableModel { - return &structTableModel{ - table: table, - } -} - -func newStructTableModelValue(v reflect.Value) *structTableModel { - return &structTableModel{ - table: GetTable(v.Type()), - root: v, - strct: v, - } -} - -func (*structTableModel) useQueryOne() bool { - return true -} - -func (m *structTableModel) String() string { - return m.table.String() -} - -func (m *structTableModel) IsNil() bool { - return !m.strct.IsValid() -} - -func (m *structTableModel) Table() *Table { - return m.table -} - -func (m *structTableModel) Relation() *Relation { - return m.rel -} - -func (m *structTableModel) AppendParam(fmter QueryFormatter, b []byte, name string) ([]byte, bool) { - b, ok := m.table.AppendParam(b, m.strct, name) - if ok { - return b, true - } - - switch name { - case "TableName": - b = fmter.FormatQuery(b, string(m.table.SQLName)) - return b, true - case "TableAlias": - b = append(b, m.table.Alias...) - return b, true - case "TableColumns": - b = appendColumns(b, m.table.Alias, m.table.Fields) - return b, true - case "Columns": - b = appendColumns(b, "", m.table.Fields) - return b, true - case "TablePKs": - b = appendColumns(b, m.table.Alias, m.table.PKs) - return b, true - case "PKs": - b = appendColumns(b, "", m.table.PKs) - return b, true - } - - return b, false -} - -func (m *structTableModel) Root() reflect.Value { - return m.root -} - -func (m *structTableModel) Index() []int { - return m.index -} - -func (m *structTableModel) ParentIndex() []int { - return m.index[:len(m.index)-len(m.rel.Field.Index)] -} - -func (m *structTableModel) Kind() reflect.Kind { - return reflect.Struct -} - -func (m *structTableModel) Value() reflect.Value { - return m.strct -} - -func (m *structTableModel) Mount(host reflect.Value) { - m.strct = host.FieldByIndex(m.rel.Field.Index) - m.structInited = false -} - -func (m *structTableModel) initStruct() error { - if m.structInited { - return m.structInitErr - } - m.structInited = true - - switch m.strct.Kind() { - case reflect.Invalid: - m.structInitErr = errModelNil - return m.structInitErr - case reflect.Interface: - m.strct = m.strct.Elem() - } - - if m.strct.Kind() == reflect.Ptr { - if m.strct.IsNil() { - m.strct.Set(reflect.New(m.strct.Type().Elem())) - m.strct = m.strct.Elem() - } else { - m.strct = m.strct.Elem() - } - } - - m.mountJoins() - - return nil -} - -func (m *structTableModel) mountJoins() { - for i := range m.joins { - j := &m.joins[i] - switch j.Rel.Type { - case HasOneRelation, BelongsToRelation: - j.JoinModel.Mount(m.strct) - } - } -} - -func (structTableModel) Init() error { - return nil -} - -func (m *structTableModel) NextColumnScanner() ColumnScanner { - return m -} - -func (m *structTableModel) AddColumnScanner(_ ColumnScanner) error { - return nil -} - -var _ BeforeScanHook = (*structTableModel)(nil) - -func (m *structTableModel) BeforeScan(ctx context.Context) error { - if !m.table.hasFlag(beforeScanHookFlag) { - return nil - } - return callBeforeScanHook(ctx, m.strct.Addr()) -} - -var _ AfterScanHook = (*structTableModel)(nil) - -func (m *structTableModel) AfterScan(ctx context.Context) error { - if !m.table.hasFlag(afterScanHookFlag) || !m.structInited { - return nil - } - - var firstErr error - - if err := callAfterScanHook(ctx, m.strct.Addr()); err != nil && firstErr == nil { - firstErr = err - } - - for _, j := range m.joins { - switch j.Rel.Type { - case HasOneRelation, BelongsToRelation: - if err := j.JoinModel.AfterScan(ctx); err != nil && firstErr == nil { - firstErr = err - } - } - } - - return firstErr -} - -func (m *structTableModel) AfterSelect(ctx context.Context) error { - if m.table.hasFlag(afterSelectHookFlag) { - return callAfterSelectHook(ctx, m.strct.Addr()) - } - return nil -} - -func (m *structTableModel) BeforeInsert(ctx context.Context) (context.Context, error) { - if m.table.hasFlag(beforeInsertHookFlag) { - return callBeforeInsertHook(ctx, m.strct.Addr()) - } - return ctx, nil -} - -func (m *structTableModel) AfterInsert(ctx context.Context) error { - if m.table.hasFlag(afterInsertHookFlag) { - return callAfterInsertHook(ctx, m.strct.Addr()) - } - return nil -} - -func (m *structTableModel) BeforeUpdate(ctx context.Context) (context.Context, error) { - if m.table.hasFlag(beforeUpdateHookFlag) && !m.IsNil() { - return callBeforeUpdateHook(ctx, m.strct.Addr()) - } - return ctx, nil -} - -func (m *structTableModel) AfterUpdate(ctx context.Context) error { - if m.table.hasFlag(afterUpdateHookFlag) && !m.IsNil() { - return callAfterUpdateHook(ctx, m.strct.Addr()) - } - return nil -} - -func (m *structTableModel) BeforeDelete(ctx context.Context) (context.Context, error) { - if m.table.hasFlag(beforeDeleteHookFlag) && !m.IsNil() { - return callBeforeDeleteHook(ctx, m.strct.Addr()) - } - return ctx, nil -} - -func (m *structTableModel) AfterDelete(ctx context.Context) error { - if m.table.hasFlag(afterDeleteHookFlag) && !m.IsNil() { - return callAfterDeleteHook(ctx, m.strct.Addr()) - } - return nil -} - -func (m *structTableModel) ScanColumn( - col types.ColumnInfo, rd types.Reader, n int, -) error { - ok, err := m.scanColumn(col, rd, n) - if ok { - return err - } - if m.table.hasFlag(discardUnknownColumnsFlag) || col.Name[0] == '_' { - return nil - } - return fmt.Errorf( - "pg: can't find column=%s in %s "+ - "(prefix the column with underscore or use discard_unknown_columns)", - col.Name, m.table, - ) -} - -func (m *structTableModel) scanColumn(col types.ColumnInfo, rd types.Reader, n int) (bool, error) { - // Don't init nil struct if value is NULL. - if n == -1 && - !m.structInited && - m.strct.Kind() == reflect.Ptr && - m.strct.IsNil() { - return true, nil - } - - if err := m.initStruct(); err != nil { - return true, err - } - - joinName, fieldName := splitColumn(col.Name) - if joinName != "" { - if join := m.GetJoin(joinName); join != nil { - joinCol := col - joinCol.Name = fieldName - return join.JoinModel.scanColumn(joinCol, rd, n) - } - if m.table.ModelName == joinName { - joinCol := col - joinCol.Name = fieldName - return m.scanColumn(joinCol, rd, n) - } - } - - field, ok := m.table.FieldsMap[col.Name] - if !ok { - return false, nil - } - - return true, field.ScanValue(m.strct, rd, n) -} - -func (m *structTableModel) GetJoin(name string) *join { - for i := range m.joins { - j := &m.joins[i] - if j.Rel.Field.GoName == name || j.Rel.Field.SQLName == name { - return j - } - } - return nil -} - -func (m *structTableModel) GetJoins() []join { - return m.joins -} - -func (m *structTableModel) AddJoin(j join) *join { - m.joins = append(m.joins, j) - return &m.joins[len(m.joins)-1] -} - -func (m *structTableModel) Join(name string, apply func(*Query) (*Query, error)) *join { - return m.join(m.Value(), name, apply) -} - -func (m *structTableModel) join( - bind reflect.Value, name string, apply func(*Query) (*Query, error), -) *join { - path := strings.Split(name, ".") - index := make([]int, 0, len(path)) - - currJoin := join{ - BaseModel: m, - JoinModel: m, - } - var lastJoin *join - var hasColumnName bool - - for _, name := range path { - rel, ok := currJoin.JoinModel.Table().Relations[name] - if !ok { - hasColumnName = true - break - } - - currJoin.Rel = rel - index = append(index, rel.Field.Index...) - - if j := currJoin.JoinModel.GetJoin(name); j != nil { - currJoin.BaseModel = j.BaseModel - currJoin.JoinModel = j.JoinModel - - lastJoin = j - } else { - model, err := newTableModelIndex(m.table.Type, bind, index, rel) - if err != nil { - return nil - } - - currJoin.Parent = lastJoin - currJoin.BaseModel = currJoin.JoinModel - currJoin.JoinModel = model - - lastJoin = currJoin.BaseModel.AddJoin(currJoin) - } - } - - // No joins with such name. - if lastJoin == nil { - return nil - } - if apply != nil { - lastJoin.ApplyQuery = apply - } - - if hasColumnName { - column := path[len(path)-1] - if column == "_" { - if lastJoin.Columns == nil { - lastJoin.Columns = make([]string, 0) - } - } else { - lastJoin.Columns = append(lastJoin.Columns, column) - } - } - - return lastJoin -} - -func (m *structTableModel) setSoftDeleteField() error { - fv := m.table.SoftDeleteField.Value(m.strct) - return m.table.SetSoftDeleteField(fv) -} - -func splitColumn(s string) (string, string) { - ind := strings.Index(s, "__") - if ind == -1 { - return "", s - } - return s[:ind], s[ind+2:] -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/msgpack.go b/vendor/github.com/go-pg/pg/v10/orm/msgpack.go deleted file mode 100644 index 56c88a23e..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/msgpack.go +++ /dev/null @@ -1,52 +0,0 @@ -package orm - -import ( - "reflect" - - "github.com/vmihailenco/msgpack/v5" - - "github.com/go-pg/pg/v10/types" -) - -func msgpackAppender(_ reflect.Type) types.AppenderFunc { - return func(b []byte, v reflect.Value, flags int) []byte { - hexEnc := types.NewHexEncoder(b, flags) - - enc := msgpack.GetEncoder() - defer msgpack.PutEncoder(enc) - - enc.Reset(hexEnc) - if err := enc.EncodeValue(v); err != nil { - return types.AppendError(b, err) - } - - if err := hexEnc.Close(); err != nil { - return types.AppendError(b, err) - } - - return hexEnc.Bytes() - } -} - -func msgpackScanner(_ reflect.Type) types.ScannerFunc { - return func(v reflect.Value, rd types.Reader, n int) error { - if n <= 0 { - return nil - } - - hexDec, err := types.NewHexDecoder(rd, n) - if err != nil { - return err - } - - dec := msgpack.GetDecoder() - defer msgpack.PutDecoder(dec) - - dec.Reset(hexDec) - if err := dec.DecodeValue(v); err != nil { - return err - } - - return nil - } -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/orm.go b/vendor/github.com/go-pg/pg/v10/orm/orm.go deleted file mode 100644 index d18993d2d..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/orm.go +++ /dev/null @@ -1,58 +0,0 @@ -/* -The API in this package is not stable and may change without any notice. -*/ -package orm - -import ( - "context" - "io" - - "github.com/go-pg/pg/v10/types" -) - -// ColumnScanner is used to scan column values. -type ColumnScanner interface { - // Scan assigns a column value from a row. - // - // An error should be returned if the value can not be stored - // without loss of information. - ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error -} - -type QueryAppender interface { - AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) -} - -type TemplateAppender interface { - AppendTemplate(b []byte) ([]byte, error) -} - -type QueryCommand interface { - QueryAppender - TemplateAppender - String() string - Operation() QueryOp - Clone() QueryCommand - Query() *Query -} - -// DB is a common interface for pg.DB and pg.Tx types. -type DB interface { - Model(model ...interface{}) *Query - ModelContext(c context.Context, model ...interface{}) *Query - - Exec(query interface{}, params ...interface{}) (Result, error) - ExecContext(c context.Context, query interface{}, params ...interface{}) (Result, error) - ExecOne(query interface{}, params ...interface{}) (Result, error) - ExecOneContext(c context.Context, query interface{}, params ...interface{}) (Result, error) - Query(model, query interface{}, params ...interface{}) (Result, error) - QueryContext(c context.Context, model, query interface{}, params ...interface{}) (Result, error) - QueryOne(model, query interface{}, params ...interface{}) (Result, error) - QueryOneContext(c context.Context, model, query interface{}, params ...interface{}) (Result, error) - - CopyFrom(r io.Reader, query interface{}, params ...interface{}) (Result, error) - CopyTo(w io.Writer, query interface{}, params ...interface{}) (Result, error) - - Context() context.Context - Formatter() QueryFormatter -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/query.go b/vendor/github.com/go-pg/pg/v10/orm/query.go deleted file mode 100644 index 8a9231f65..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/query.go +++ /dev/null @@ -1,1680 +0,0 @@ -package orm - -import ( - "context" - "errors" - "fmt" - "io" - "reflect" - "strconv" - "strings" - "sync" - "time" - - "github.com/go-pg/pg/v10/internal" - "github.com/go-pg/pg/v10/types" -) - -type QueryOp string - -const ( - SelectOp QueryOp = "SELECT" - InsertOp QueryOp = "INSERT" - UpdateOp QueryOp = "UPDATE" - DeleteOp QueryOp = "DELETE" - CreateTableOp QueryOp = "CREATE TABLE" - DropTableOp QueryOp = "DROP TABLE" - CreateCompositeOp QueryOp = "CREATE COMPOSITE" - DropCompositeOp QueryOp = "DROP COMPOSITE" -) - -type queryFlag uint8 - -const ( - implicitModelFlag queryFlag = 1 << iota - deletedFlag - allWithDeletedFlag -) - -type withQuery struct { - name string - query QueryAppender -} - -type columnValue struct { - column string - value *SafeQueryAppender -} - -type union struct { - expr string - query *Query -} - -type Query struct { - ctx context.Context - db DB - stickyErr error - - model Model - tableModel TableModel - flags queryFlag - - with []withQuery - tables []QueryAppender - distinctOn []*SafeQueryAppender - columns []QueryAppender - set []QueryAppender - modelValues map[string]*SafeQueryAppender - extraValues []*columnValue - where []queryWithSepAppender - updWhere []queryWithSepAppender - group []QueryAppender - having []*SafeQueryAppender - union []*union - joins []QueryAppender - joinAppendOn func(app *condAppender) - order []QueryAppender - limit int - offset int - selFor *SafeQueryAppender - - onConflict *SafeQueryAppender - returning []*SafeQueryAppender -} - -func NewQuery(db DB, model ...interface{}) *Query { - ctx := context.Background() - if db != nil { - ctx = db.Context() - } - q := &Query{ctx: ctx} - return q.DB(db).Model(model...) -} - -func NewQueryContext(ctx context.Context, db DB, model ...interface{}) *Query { - return NewQuery(db, model...).Context(ctx) -} - -// New returns new zero Query bound to the current db. -func (q *Query) New() *Query { - clone := &Query{ - ctx: q.ctx, - db: q.db, - - model: q.model, - tableModel: cloneTableModelJoins(q.tableModel), - flags: q.flags, - } - return clone.withFlag(implicitModelFlag) -} - -// Clone clones the Query. -func (q *Query) Clone() *Query { - var modelValues map[string]*SafeQueryAppender - if len(q.modelValues) > 0 { - modelValues = make(map[string]*SafeQueryAppender, len(q.modelValues)) - for k, v := range q.modelValues { - modelValues[k] = v - } - } - - clone := &Query{ - ctx: q.ctx, - db: q.db, - stickyErr: q.stickyErr, - - model: q.model, - tableModel: cloneTableModelJoins(q.tableModel), - flags: q.flags, - - with: q.with[:len(q.with):len(q.with)], - tables: q.tables[:len(q.tables):len(q.tables)], - distinctOn: q.distinctOn[:len(q.distinctOn):len(q.distinctOn)], - columns: q.columns[:len(q.columns):len(q.columns)], - set: q.set[:len(q.set):len(q.set)], - modelValues: modelValues, - extraValues: q.extraValues[:len(q.extraValues):len(q.extraValues)], - where: q.where[:len(q.where):len(q.where)], - updWhere: q.updWhere[:len(q.updWhere):len(q.updWhere)], - joins: q.joins[:len(q.joins):len(q.joins)], - group: q.group[:len(q.group):len(q.group)], - having: q.having[:len(q.having):len(q.having)], - union: q.union[:len(q.union):len(q.union)], - order: q.order[:len(q.order):len(q.order)], - limit: q.limit, - offset: q.offset, - selFor: q.selFor, - - onConflict: q.onConflict, - returning: q.returning[:len(q.returning):len(q.returning)], - } - - return clone -} - -func cloneTableModelJoins(tm TableModel) TableModel { - switch tm := tm.(type) { - case *structTableModel: - if len(tm.joins) == 0 { - return tm - } - clone := *tm - clone.joins = clone.joins[:len(clone.joins):len(clone.joins)] - return &clone - case *sliceTableModel: - if len(tm.joins) == 0 { - return tm - } - clone := *tm - clone.joins = clone.joins[:len(clone.joins):len(clone.joins)] - return &clone - } - return tm -} - -func (q *Query) err(err error) *Query { - if q.stickyErr == nil { - q.stickyErr = err - } - return q -} - -func (q *Query) hasFlag(flag queryFlag) bool { - return hasFlag(q.flags, flag) -} - -func hasFlag(flags, flag queryFlag) bool { - return flags&flag != 0 -} - -func (q *Query) withFlag(flag queryFlag) *Query { - q.flags |= flag - return q -} - -func (q *Query) withoutFlag(flag queryFlag) *Query { - q.flags &= ^flag - return q -} - -func (q *Query) Context(c context.Context) *Query { - q.ctx = c - return q -} - -func (q *Query) DB(db DB) *Query { - q.db = db - return q -} - -func (q *Query) Model(model ...interface{}) *Query { - var err error - switch l := len(model); { - case l == 0: - q.model = nil - case l == 1: - q.model, err = NewModel(model[0]) - case l > 1: - q.model, err = NewModel(&model) - default: - panic("not reached") - } - if err != nil { - q = q.err(err) - } - - q.tableModel, _ = q.model.(TableModel) - - return q.withoutFlag(implicitModelFlag) -} - -func (q *Query) TableModel() TableModel { - return q.tableModel -} - -func (q *Query) isSoftDelete() bool { - if q.tableModel != nil { - return q.tableModel.Table().SoftDeleteField != nil && !q.hasFlag(allWithDeletedFlag) - } - return false -} - -// Deleted adds `WHERE deleted_at IS NOT NULL` clause for soft deleted models. -func (q *Query) Deleted() *Query { - if q.tableModel != nil { - if err := q.tableModel.Table().mustSoftDelete(); err != nil { - return q.err(err) - } - } - return q.withFlag(deletedFlag).withoutFlag(allWithDeletedFlag) -} - -// AllWithDeleted changes query to return all rows including soft deleted ones. -func (q *Query) AllWithDeleted() *Query { - if q.tableModel != nil { - if err := q.tableModel.Table().mustSoftDelete(); err != nil { - return q.err(err) - } - } - return q.withFlag(allWithDeletedFlag).withoutFlag(deletedFlag) -} - -// With adds subq as common table expression with the given name. -func (q *Query) With(name string, subq *Query) *Query { - return q._with(name, NewSelectQuery(subq)) -} - -func (q *Query) WithInsert(name string, subq *Query) *Query { - return q._with(name, NewInsertQuery(subq)) -} - -func (q *Query) WithUpdate(name string, subq *Query) *Query { - return q._with(name, NewUpdateQuery(subq, false)) -} - -func (q *Query) WithDelete(name string, subq *Query) *Query { - return q._with(name, NewDeleteQuery(subq)) -} - -func (q *Query) _with(name string, subq QueryAppender) *Query { - q.with = append(q.with, withQuery{ - name: name, - query: subq, - }) - return q -} - -// WrapWith creates new Query and adds to it current query as -// common table expression with the given name. -func (q *Query) WrapWith(name string) *Query { - wrapper := q.New() - wrapper.with = q.with - q.with = nil - wrapper = wrapper.With(name, q) - return wrapper -} - -func (q *Query) Table(tables ...string) *Query { - for _, table := range tables { - q.tables = append(q.tables, fieldAppender{table}) - } - return q -} - -func (q *Query) TableExpr(expr string, params ...interface{}) *Query { - q.tables = append(q.tables, SafeQuery(expr, params...)) - return q -} - -func (q *Query) Distinct() *Query { - q.distinctOn = make([]*SafeQueryAppender, 0) - return q -} - -func (q *Query) DistinctOn(expr string, params ...interface{}) *Query { - q.distinctOn = append(q.distinctOn, SafeQuery(expr, params...)) - return q -} - -// Column adds a column to the Query quoting it according to PostgreSQL rules. -// Does not expand params like ?TableAlias etc. -// ColumnExpr can be used to bypass quoting restriction or for params expansion. -// Column name can be: -// - column_name, -// - table_alias.column_name, -// - table_alias.*. -func (q *Query) Column(columns ...string) *Query { - for _, column := range columns { - if column == "_" { - if q.columns == nil { - q.columns = make([]QueryAppender, 0) - } - continue - } - - q.columns = append(q.columns, fieldAppender{column}) - } - return q -} - -// ColumnExpr adds column expression to the Query. -func (q *Query) ColumnExpr(expr string, params ...interface{}) *Query { - q.columns = append(q.columns, SafeQuery(expr, params...)) - return q -} - -// ExcludeColumn excludes a column from the list of to be selected columns. -func (q *Query) ExcludeColumn(columns ...string) *Query { - if q.columns == nil { - for _, f := range q.tableModel.Table().Fields { - q.columns = append(q.columns, fieldAppender{f.SQLName}) - } - } - - for _, col := range columns { - if !q.excludeColumn(col) { - return q.err(fmt.Errorf("pg: can't find column=%q", col)) - } - } - return q -} - -func (q *Query) excludeColumn(column string) bool { - for i := 0; i < len(q.columns); i++ { - app, ok := q.columns[i].(fieldAppender) - if ok && app.field == column { - q.columns = append(q.columns[:i], q.columns[i+1:]...) - return true - } - } - return false -} - -func (q *Query) getFields() ([]*Field, error) { - return q._getFields(false) -} - -func (q *Query) getDataFields() ([]*Field, error) { - return q._getFields(true) -} - -func (q *Query) _getFields(omitPKs bool) ([]*Field, error) { - table := q.tableModel.Table() - columns := make([]*Field, 0, len(q.columns)) - for _, col := range q.columns { - f, ok := col.(fieldAppender) - if !ok { - continue - } - - field, err := table.GetField(f.field) - if err != nil { - return nil, err - } - - if omitPKs && field.hasFlag(PrimaryKeyFlag) { - continue - } - - columns = append(columns, field) - } - return columns, nil -} - -// Relation adds a relation to the query. Relation name can be: -// - RelationName to select all columns, -// - RelationName.column_name, -// - RelationName._ to join relation without selecting relation columns. -func (q *Query) Relation(name string, apply ...func(*Query) (*Query, error)) *Query { - var fn func(*Query) (*Query, error) - if len(apply) == 1 { - fn = apply[0] - } else if len(apply) > 1 { - panic("only one apply function is supported") - } - - join := q.tableModel.Join(name, fn) - if join == nil { - return q.err(fmt.Errorf("%s does not have relation=%q", - q.tableModel.Table(), name)) - } - - if fn == nil { - return q - } - - switch join.Rel.Type { - case HasOneRelation, BelongsToRelation: - q.joinAppendOn = join.AppendOn - return q.Apply(fn) - default: - q.joinAppendOn = nil - return q - } -} - -func (q *Query) Set(set string, params ...interface{}) *Query { - q.set = append(q.set, SafeQuery(set, params...)) - return q -} - -// Value overwrites model value for the column in INSERT and UPDATE queries. -func (q *Query) Value(column string, value string, params ...interface{}) *Query { - if !q.hasTableModel() { - q.err(errModelNil) - return q - } - - table := q.tableModel.Table() - if _, ok := table.FieldsMap[column]; ok { - if q.modelValues == nil { - q.modelValues = make(map[string]*SafeQueryAppender) - } - q.modelValues[column] = SafeQuery(value, params...) - } else { - q.extraValues = append(q.extraValues, &columnValue{ - column: column, - value: SafeQuery(value, params...), - }) - } - - return q -} - -func (q *Query) Where(condition string, params ...interface{}) *Query { - q.addWhere(&condAppender{ - sep: " AND ", - cond: condition, - params: params, - }) - return q -} - -func (q *Query) WhereOr(condition string, params ...interface{}) *Query { - q.addWhere(&condAppender{ - sep: " OR ", - cond: condition, - params: params, - }) - return q -} - -// WhereGroup encloses conditions added in the function in parentheses. -// -// q.Where("TRUE"). -// WhereGroup(func(q *pg.Query) (*pg.Query, error) { -// q = q.WhereOr("FALSE").WhereOr("TRUE"). -// return q, nil -// }) -// -// generates -// -// WHERE TRUE AND (FALSE OR TRUE) -func (q *Query) WhereGroup(fn func(*Query) (*Query, error)) *Query { - return q.whereGroup(" AND ", fn) -} - -// WhereGroup encloses conditions added in the function in parentheses. -// -// q.Where("TRUE"). -// WhereNotGroup(func(q *pg.Query) (*pg.Query, error) { -// q = q.WhereOr("FALSE").WhereOr("TRUE"). -// return q, nil -// }) -// -// generates -// -// WHERE TRUE AND NOT (FALSE OR TRUE) -func (q *Query) WhereNotGroup(fn func(*Query) (*Query, error)) *Query { - return q.whereGroup(" AND NOT ", fn) -} - -// WhereOrGroup encloses conditions added in the function in parentheses. -// -// q.Where("TRUE"). -// WhereOrGroup(func(q *pg.Query) (*pg.Query, error) { -// q = q.Where("FALSE").Where("TRUE"). -// return q, nil -// }) -// -// generates -// -// WHERE TRUE OR (FALSE AND TRUE) -func (q *Query) WhereOrGroup(fn func(*Query) (*Query, error)) *Query { - return q.whereGroup(" OR ", fn) -} - -// WhereOrGroup encloses conditions added in the function in parentheses. -// -// q.Where("TRUE"). -// WhereOrGroup(func(q *pg.Query) (*pg.Query, error) { -// q = q.Where("FALSE").Where("TRUE"). -// return q, nil -// }) -// -// generates -// -// WHERE TRUE OR NOT (FALSE AND TRUE) -func (q *Query) WhereOrNotGroup(fn func(*Query) (*Query, error)) *Query { - return q.whereGroup(" OR NOT ", fn) -} - -func (q *Query) whereGroup(conj string, fn func(*Query) (*Query, error)) *Query { - saved := q.where - q.where = nil - - newq, err := fn(q) - if err != nil { - q.err(err) - return q - } - - if len(newq.where) == 0 { - newq.where = saved - return newq - } - - f := &condGroupAppender{ - sep: conj, - cond: newq.where, - } - newq.where = saved - newq.addWhere(f) - - return newq -} - -// WhereIn is a shortcut for Where and pg.In. -func (q *Query) WhereIn(where string, slice interface{}) *Query { - return q.Where(where, types.In(slice)) -} - -// WhereInMulti is a shortcut for Where and pg.InMulti. -func (q *Query) WhereInMulti(where string, values ...interface{}) *Query { - return q.Where(where, types.InMulti(values...)) -} - -func (q *Query) addWhere(f queryWithSepAppender) { - if q.onConflictDoUpdate() { - q.updWhere = append(q.updWhere, f) - } else { - q.where = append(q.where, f) - } -} - -// WherePK adds condition based on the model primary keys. -// Usually it is the same as: -// -// Where("id = ?id") -func (q *Query) WherePK() *Query { - if !q.hasTableModel() { - q.err(errModelNil) - return q - } - - if err := q.tableModel.Table().checkPKs(); err != nil { - q.err(err) - return q - } - - switch q.tableModel.Kind() { - case reflect.Struct: - q.where = append(q.where, wherePKStructQuery{q}) - return q - case reflect.Slice: - q.joins = append(q.joins, joinPKSliceQuery{q: q}) - q.where = append(q.where, wherePKSliceQuery{q: q}) - q = q.OrderExpr(`"_data"."ordering" ASC`) - return q - } - - panic("not reached") -} - -func (q *Query) Join(join string, params ...interface{}) *Query { - j := &joinQuery{ - join: SafeQuery(join, params...), - } - q.joins = append(q.joins, j) - q.joinAppendOn = j.AppendOn - return q -} - -// JoinOn appends join condition to the last join. -func (q *Query) JoinOn(condition string, params ...interface{}) *Query { - if q.joinAppendOn == nil { - q.err(errors.New("pg: no joins to apply JoinOn")) - return q - } - q.joinAppendOn(&condAppender{ - sep: " AND ", - cond: condition, - params: params, - }) - return q -} - -func (q *Query) JoinOnOr(condition string, params ...interface{}) *Query { - if q.joinAppendOn == nil { - q.err(errors.New("pg: no joins to apply JoinOn")) - return q - } - q.joinAppendOn(&condAppender{ - sep: " OR ", - cond: condition, - params: params, - }) - return q -} - -func (q *Query) Group(columns ...string) *Query { - for _, column := range columns { - q.group = append(q.group, fieldAppender{column}) - } - return q -} - -func (q *Query) GroupExpr(group string, params ...interface{}) *Query { - q.group = append(q.group, SafeQuery(group, params...)) - return q -} - -func (q *Query) Having(having string, params ...interface{}) *Query { - q.having = append(q.having, SafeQuery(having, params...)) - return q -} - -func (q *Query) Union(other *Query) *Query { - return q.addUnion(" UNION ", other) -} - -func (q *Query) UnionAll(other *Query) *Query { - return q.addUnion(" UNION ALL ", other) -} - -func (q *Query) Intersect(other *Query) *Query { - return q.addUnion(" INTERSECT ", other) -} - -func (q *Query) IntersectAll(other *Query) *Query { - return q.addUnion(" INTERSECT ALL ", other) -} - -func (q *Query) Except(other *Query) *Query { - return q.addUnion(" EXCEPT ", other) -} - -func (q *Query) ExceptAll(other *Query) *Query { - return q.addUnion(" EXCEPT ALL ", other) -} - -func (q *Query) addUnion(expr string, other *Query) *Query { - q.union = append(q.union, &union{ - expr: expr, - query: other, - }) - return q -} - -// Order adds sort order to the Query quoting column name. Does not expand params like ?TableAlias etc. -// OrderExpr can be used to bypass quoting restriction or for params expansion. -func (q *Query) Order(orders ...string) *Query { -loop: - for _, order := range orders { - if order == "" { - continue - } - ind := strings.Index(order, " ") - if ind != -1 { - field := order[:ind] - sort := order[ind+1:] - switch internal.UpperString(sort) { - case "ASC", "DESC", "ASC NULLS FIRST", "DESC NULLS FIRST", - "ASC NULLS LAST", "DESC NULLS LAST": - q = q.OrderExpr("? ?", types.Ident(field), types.Safe(sort)) - continue loop - } - } - - q.order = append(q.order, fieldAppender{order}) - } - return q -} - -// Order adds sort order to the Query. -func (q *Query) OrderExpr(order string, params ...interface{}) *Query { - if order != "" { - q.order = append(q.order, SafeQuery(order, params...)) - } - return q -} - -func (q *Query) Limit(n int) *Query { - q.limit = n - return q -} - -func (q *Query) Offset(n int) *Query { - q.offset = n - return q -} - -func (q *Query) OnConflict(s string, params ...interface{}) *Query { - q.onConflict = SafeQuery(s, params...) - return q -} - -func (q *Query) onConflictDoUpdate() bool { - return q.onConflict != nil && - strings.HasSuffix(internal.UpperString(q.onConflict.query), "DO UPDATE") -} - -// Returning adds a RETURNING clause to the query. -// -// `Returning("NULL")` can be used to suppress default returning clause -// generated by go-pg for INSERT queries to get values for null columns. -func (q *Query) Returning(s string, params ...interface{}) *Query { - q.returning = append(q.returning, SafeQuery(s, params...)) - return q -} - -func (q *Query) For(s string, params ...interface{}) *Query { - q.selFor = SafeQuery(s, params...) - return q -} - -// Apply calls the fn passing the Query as an argument. -func (q *Query) Apply(fn func(*Query) (*Query, error)) *Query { - qq, err := fn(q) - if err != nil { - q.err(err) - return q - } - return qq -} - -// Count returns number of rows matching the query using count aggregate function. -func (q *Query) Count() (int, error) { - if q.stickyErr != nil { - return 0, q.stickyErr - } - - var count int - _, err := q.db.QueryOneContext( - q.ctx, Scan(&count), q.countSelectQuery("count(*)"), q.tableModel) - return count, err -} - -func (q *Query) countSelectQuery(column string) *SelectQuery { - return &SelectQuery{ - q: q, - count: column, - } -} - -// First sorts rows by primary key and selects the first row. -// It is a shortcut for: -// -// q.OrderExpr("id ASC").Limit(1) -func (q *Query) First() error { - table := q.tableModel.Table() - - if err := table.checkPKs(); err != nil { - return err - } - - b := appendColumns(nil, table.Alias, table.PKs) - return q.OrderExpr(internal.BytesToString(b)).Limit(1).Select() -} - -// Last sorts rows by primary key and selects the last row. -// It is a shortcut for: -// -// q.OrderExpr("id DESC").Limit(1) -func (q *Query) Last() error { - table := q.tableModel.Table() - - if err := table.checkPKs(); err != nil { - return err - } - - // TODO: fix for multi columns - b := appendColumns(nil, table.Alias, table.PKs) - b = append(b, " DESC"...) - return q.OrderExpr(internal.BytesToString(b)).Limit(1).Select() -} - -// Select selects the model. -func (q *Query) Select(values ...interface{}) error { - if q.stickyErr != nil { - return q.stickyErr - } - - model, err := q.newModel(values) - if err != nil { - return err - } - - res, err := q.query(q.ctx, model, NewSelectQuery(q)) - if err != nil { - return err - } - - if res.RowsReturned() > 0 { - if q.tableModel != nil { - if err := q.selectJoins(q.tableModel.GetJoins()); err != nil { - return err - } - } - } - - if err := model.AfterSelect(q.ctx); err != nil { - return err - } - - return nil -} - -func (q *Query) newModel(values []interface{}) (Model, error) { - if len(values) > 0 { - return newScanModel(values) - } - return q.tableModel, nil -} - -func (q *Query) query(ctx context.Context, model Model, query interface{}) (Result, error) { - if _, ok := model.(useQueryOne); ok { - return q.db.QueryOneContext(ctx, model, query, q.tableModel) - } - return q.db.QueryContext(ctx, model, query, q.tableModel) -} - -// SelectAndCount runs Select and Count in two goroutines, -// waits for them to finish and returns the result. If query limit is -1 -// it does not select any data and only counts the results. -func (q *Query) SelectAndCount(values ...interface{}) (count int, firstErr error) { - if q.stickyErr != nil { - return 0, q.stickyErr - } - - var wg sync.WaitGroup - var mu sync.Mutex - - if q.limit >= 0 { - wg.Add(1) - go func() { - defer wg.Done() - err := q.Select(values...) - if err != nil { - mu.Lock() - if firstErr == nil { - firstErr = err - } - mu.Unlock() - } - }() - } - - wg.Add(1) - go func() { - defer wg.Done() - var err error - count, err = q.Count() - if err != nil { - mu.Lock() - if firstErr == nil { - firstErr = err - } - mu.Unlock() - } - }() - - wg.Wait() - return count, firstErr -} - -// SelectAndCountEstimate runs Select and CountEstimate in two goroutines, -// waits for them to finish and returns the result. If query limit is -1 -// it does not select any data and only counts the results. -func (q *Query) SelectAndCountEstimate(threshold int, values ...interface{}) (count int, firstErr error) { - if q.stickyErr != nil { - return 0, q.stickyErr - } - - var wg sync.WaitGroup - var mu sync.Mutex - - if q.limit >= 0 { - wg.Add(1) - go func() { - defer wg.Done() - err := q.Select(values...) - if err != nil { - mu.Lock() - if firstErr == nil { - firstErr = err - } - mu.Unlock() - } - }() - } - - wg.Add(1) - go func() { - defer wg.Done() - var err error - count, err = q.CountEstimate(threshold) - if err != nil { - mu.Lock() - if firstErr == nil { - firstErr = err - } - mu.Unlock() - } - }() - - wg.Wait() - return count, firstErr -} - -// ForEach calls the function for each row returned by the query -// without loading all rows into the memory. -// -// Function can accept a struct, a pointer to a struct, an orm.Model, -// or values for the columns in a row. Function must return an error. -func (q *Query) ForEach(fn interface{}) error { - m := newFuncModel(fn) - return q.Select(m) -} - -func (q *Query) forEachHasOneJoin(fn func(*join) error) error { - if q.tableModel == nil { - return nil - } - return q._forEachHasOneJoin(fn, q.tableModel.GetJoins()) -} - -func (q *Query) _forEachHasOneJoin(fn func(*join) error, joins []join) error { - for i := range joins { - j := &joins[i] - switch j.Rel.Type { - case HasOneRelation, BelongsToRelation: - err := fn(j) - if err != nil { - return err - } - - err = q._forEachHasOneJoin(fn, j.JoinModel.GetJoins()) - if err != nil { - return err - } - } - } - return nil -} - -func (q *Query) selectJoins(joins []join) error { - var err error - for i := range joins { - j := &joins[i] - if j.Rel.Type == HasOneRelation || j.Rel.Type == BelongsToRelation { - err = q.selectJoins(j.JoinModel.GetJoins()) - } else { - err = j.Select(q.db.Formatter(), q.New()) - } - if err != nil { - return err - } - } - return nil -} - -// Insert inserts the model. -func (q *Query) Insert(values ...interface{}) (Result, error) { - if q.stickyErr != nil { - return nil, q.stickyErr - } - - model, err := q.newModel(values) - if err != nil { - return nil, err - } - - ctx := q.ctx - - if q.tableModel != nil && q.tableModel.Table().hasFlag(beforeInsertHookFlag) { - ctx, err = q.tableModel.BeforeInsert(ctx) - if err != nil { - return nil, err - } - } - - query := NewInsertQuery(q) - res, err := q.returningQuery(ctx, model, query) - if err != nil { - return nil, err - } - - if q.tableModel != nil { - if err := q.tableModel.AfterInsert(ctx); err != nil { - return nil, err - } - } - - return res, nil -} - -// SelectOrInsert selects the model inserting one if it does not exist. -// It returns true when model was inserted. -func (q *Query) SelectOrInsert(values ...interface{}) (inserted bool, _ error) { - if q.stickyErr != nil { - return false, q.stickyErr - } - - var insertq *Query - var insertErr error - for i := 0; i < 5; i++ { - if i >= 2 { - dur := internal.RetryBackoff(i-2, 250*time.Millisecond, 5*time.Second) - if err := internal.Sleep(q.ctx, dur); err != nil { - return false, err - } - } - - err := q.Select(values...) - if err == nil { - return false, nil - } - if err != internal.ErrNoRows { - return false, err - } - - if insertq == nil { - insertq = q - if len(insertq.columns) > 0 { - insertq = insertq.Clone() - insertq.columns = nil - } - } - - res, err := insertq.Insert(values...) - if err != nil { - insertErr = err - if err == internal.ErrNoRows { - continue - } - if pgErr, ok := err.(internal.PGError); ok { - if pgErr.IntegrityViolation() { - continue - } - if pgErr.Field('C') == "55000" { - // Retry on "#55000 attempted to delete invisible tuple". - continue - } - } - return false, err - } - if res.RowsAffected() == 1 { - return true, nil - } - } - - err := fmt.Errorf( - "pg: SelectOrInsert: select returns no rows (insert fails with err=%q)", - insertErr) - return false, err -} - -// Update updates the model. -func (q *Query) Update(scan ...interface{}) (Result, error) { - return q.update(scan, false) -} - -// Update updates the model omitting fields with zero values such as: -// - empty string, -// - 0, -// - zero time, -// - empty map or slice, -// - byte array with all zeroes, -// - nil ptr, -// - types with method `IsZero() == true`. -func (q *Query) UpdateNotZero(scan ...interface{}) (Result, error) { - return q.update(scan, true) -} - -func (q *Query) update(values []interface{}, omitZero bool) (Result, error) { - if q.stickyErr != nil { - return nil, q.stickyErr - } - - model, err := q.newModel(values) - if err != nil { - return nil, err - } - - c := q.ctx - - if q.tableModel != nil { - c, err = q.tableModel.BeforeUpdate(c) - if err != nil { - return nil, err - } - } - - query := NewUpdateQuery(q, omitZero) - res, err := q.returningQuery(c, model, query) - if err != nil { - return nil, err - } - - if q.tableModel != nil { - err = q.tableModel.AfterUpdate(c) - if err != nil { - return nil, err - } - } - - return res, nil -} - -func (q *Query) returningQuery(c context.Context, model Model, query interface{}) (Result, error) { - if !q.hasReturning() { - return q.db.QueryContext(c, model, query, q.tableModel) - } - if _, ok := model.(useQueryOne); ok { - return q.db.QueryOneContext(c, model, query, q.tableModel) - } - return q.db.QueryContext(c, model, query, q.tableModel) -} - -// Delete deletes the model. When model has deleted_at column the row -// is soft deleted instead. -func (q *Query) Delete(values ...interface{}) (Result, error) { - if q.tableModel == nil { - return q.ForceDelete(values...) - } - - table := q.tableModel.Table() - if table.SoftDeleteField == nil { - return q.ForceDelete(values...) - } - - clone := q.Clone() - if q.tableModel.IsNil() { - if table.SoftDeleteField.SQLType == pgTypeBigint { - clone = clone.Set("? = ?", table.SoftDeleteField.Column, time.Now().UnixNano()) - } else { - clone = clone.Set("? = ?", table.SoftDeleteField.Column, time.Now()) - } - } else { - if err := clone.tableModel.setSoftDeleteField(); err != nil { - return nil, err - } - clone = clone.Column(table.SoftDeleteField.SQLName) - } - return clone.Update(values...) -} - -// Delete forces delete of the model with deleted_at column. -func (q *Query) ForceDelete(values ...interface{}) (Result, error) { - if q.stickyErr != nil { - return nil, q.stickyErr - } - if q.tableModel == nil { - return nil, errModelNil - } - q = q.withFlag(deletedFlag) - - model, err := q.newModel(values) - if err != nil { - return nil, err - } - - ctx := q.ctx - - if q.tableModel != nil { - ctx, err = q.tableModel.BeforeDelete(ctx) - if err != nil { - return nil, err - } - } - - res, err := q.returningQuery(ctx, model, NewDeleteQuery(q)) - if err != nil { - return nil, err - } - - if q.tableModel != nil { - if err := q.tableModel.AfterDelete(ctx); err != nil { - return nil, err - } - } - - return res, nil -} - -func (q *Query) CreateTable(opt *CreateTableOptions) error { - _, err := q.db.ExecContext(q.ctx, NewCreateTableQuery(q, opt)) - return err -} - -func (q *Query) DropTable(opt *DropTableOptions) error { - _, err := q.db.ExecContext(q.ctx, NewDropTableQuery(q, opt)) - return err -} - -func (q *Query) CreateComposite(opt *CreateCompositeOptions) error { - _, err := q.db.ExecContext(q.ctx, NewCreateCompositeQuery(q, opt)) - return err -} - -func (q *Query) DropComposite(opt *DropCompositeOptions) error { - _, err := q.db.ExecContext(q.ctx, NewDropCompositeQuery(q, opt)) - return err -} - -// Exec is an alias for DB.Exec. -func (q *Query) Exec(query interface{}, params ...interface{}) (Result, error) { - params = append(params, q.tableModel) - return q.db.ExecContext(q.ctx, query, params...) -} - -// ExecOne is an alias for DB.ExecOne. -func (q *Query) ExecOne(query interface{}, params ...interface{}) (Result, error) { - params = append(params, q.tableModel) - return q.db.ExecOneContext(q.ctx, query, params...) -} - -// Query is an alias for DB.Query. -func (q *Query) Query(model, query interface{}, params ...interface{}) (Result, error) { - params = append(params, q.tableModel) - return q.db.QueryContext(q.ctx, model, query, params...) -} - -// QueryOne is an alias for DB.QueryOne. -func (q *Query) QueryOne(model, query interface{}, params ...interface{}) (Result, error) { - params = append(params, q.tableModel) - return q.db.QueryOneContext(q.ctx, model, query, params...) -} - -// CopyFrom is an alias from DB.CopyFrom. -func (q *Query) CopyFrom(r io.Reader, query interface{}, params ...interface{}) (Result, error) { - params = append(params, q.tableModel) - return q.db.CopyFrom(r, query, params...) -} - -// CopyTo is an alias from DB.CopyTo. -func (q *Query) CopyTo(w io.Writer, query interface{}, params ...interface{}) (Result, error) { - params = append(params, q.tableModel) - return q.db.CopyTo(w, query, params...) -} - -var _ QueryAppender = (*Query)(nil) - -func (q *Query) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { - return NewSelectQuery(q).AppendQuery(fmter, b) -} - -// Exists returns true or false depending if there are any rows matching the query. -func (q *Query) Exists() (bool, error) { - q = q.Clone() // copy to not change original query - q.columns = []QueryAppender{SafeQuery("1")} - q.order = nil - q.limit = 1 - res, err := q.db.ExecContext(q.ctx, NewSelectQuery(q)) - if err != nil { - return false, err - } - return res.RowsAffected() > 0, nil -} - -func (q *Query) hasTableModel() bool { - return q.tableModel != nil && !q.tableModel.IsNil() -} - -func (q *Query) hasExplicitTableModel() bool { - return q.tableModel != nil && !q.hasFlag(implicitModelFlag) -} - -func (q *Query) modelHasTableName() bool { - return q.hasExplicitTableModel() && q.tableModel.Table().SQLName != "" -} - -func (q *Query) modelHasTableAlias() bool { - return q.hasExplicitTableModel() && q.tableModel.Table().Alias != "" -} - -func (q *Query) hasTables() bool { - return q.modelHasTableName() || len(q.tables) > 0 -} - -func (q *Query) appendFirstTable(fmter QueryFormatter, b []byte) ([]byte, error) { - if q.modelHasTableName() { - return fmter.FormatQuery(b, string(q.tableModel.Table().SQLName)), nil - } - if len(q.tables) > 0 { - return q.tables[0].AppendQuery(fmter, b) - } - return b, nil -} - -func (q *Query) appendFirstTableWithAlias(fmter QueryFormatter, b []byte) (_ []byte, err error) { - if q.modelHasTableName() { - table := q.tableModel.Table() - b = fmter.FormatQuery(b, string(table.SQLName)) - if table.Alias != table.SQLName { - b = append(b, " AS "...) - b = append(b, table.Alias...) - } - return b, nil - } - - if len(q.tables) > 0 { - b, err = q.tables[0].AppendQuery(fmter, b) - if err != nil { - return nil, err - } - if q.modelHasTableAlias() { - table := q.tableModel.Table() - if table.Alias != table.SQLName { - b = append(b, " AS "...) - b = append(b, table.Alias...) - } - } - } - - return b, nil -} - -func (q *Query) hasMultiTables() bool { - if q.modelHasTableName() { - return len(q.tables) >= 1 - } - return len(q.tables) >= 2 -} - -func (q *Query) appendOtherTables(fmter QueryFormatter, b []byte) (_ []byte, err error) { - tables := q.tables - if !q.modelHasTableName() { - tables = tables[1:] - } - for i, f := range tables { - if i > 0 { - b = append(b, ", "...) - } - b, err = f.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - } - return b, nil -} - -func (q *Query) appendColumns(fmter QueryFormatter, b []byte) (_ []byte, err error) { - for i, f := range q.columns { - if i > 0 { - b = append(b, ", "...) - } - b, err = f.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - } - return b, nil -} - -func (q *Query) mustAppendWhere(fmter QueryFormatter, b []byte) ([]byte, error) { - if len(q.where) == 0 { - err := errors.New( - "pg: Update and Delete queries require Where clause (try WherePK)") - return nil, err - } - return q.appendWhere(fmter, b) -} - -func (q *Query) appendWhere(fmter QueryFormatter, b []byte) (_ []byte, err error) { - isSoftDelete := q.isSoftDelete() - - if len(q.where) > 0 { - if isSoftDelete { - b = append(b, '(') - } - - b, err = q._appendWhere(fmter, b, q.where) - if err != nil { - return nil, err - } - - if isSoftDelete { - b = append(b, ')') - } - } - - if isSoftDelete { - if len(q.where) > 0 { - b = append(b, " AND "...) - } - b = append(b, q.tableModel.Table().Alias...) - b = q.appendSoftDelete(b) - } - - return b, nil -} - -func (q *Query) appendSoftDelete(b []byte) []byte { - b = append(b, '.') - b = append(b, q.tableModel.Table().SoftDeleteField.Column...) - if q.hasFlag(deletedFlag) { - b = append(b, " IS NOT NULL"...) - } else { - b = append(b, " IS NULL"...) - } - return b -} - -func (q *Query) appendUpdWhere(fmter QueryFormatter, b []byte) ([]byte, error) { - return q._appendWhere(fmter, b, q.updWhere) -} - -func (q *Query) _appendWhere( - fmter QueryFormatter, b []byte, where []queryWithSepAppender, -) (_ []byte, err error) { - for i, f := range where { - start := len(b) - - if i > 0 { - b = f.AppendSep(b) - } - - before := len(b) - - b, err = f.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - - if len(b) == before { - b = b[:start] - } - } - return b, nil -} - -func (q *Query) appendSet(fmter QueryFormatter, b []byte) (_ []byte, err error) { - b = append(b, " SET "...) - for i, f := range q.set { - if i > 0 { - b = append(b, ", "...) - } - b, err = f.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - } - return b, nil -} - -func (q *Query) hasReturning() bool { - if len(q.returning) == 0 { - return false - } - if len(q.returning) == 1 { - switch q.returning[0].query { - case "null", "NULL": - return false - } - } - return true -} - -func (q *Query) appendReturning(fmter QueryFormatter, b []byte) (_ []byte, err error) { - if !q.hasReturning() { - return b, nil - } - - b = append(b, " RETURNING "...) - for i, f := range q.returning { - if i > 0 { - b = append(b, ", "...) - } - b, err = f.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - } - return b, nil -} - -func (q *Query) appendWith(fmter QueryFormatter, b []byte) (_ []byte, err error) { - b = append(b, "WITH "...) - for i, with := range q.with { - if i > 0 { - b = append(b, ", "...) - } - b = types.AppendIdent(b, with.name, 1) - b = append(b, " AS ("...) - - b, err = with.query.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - - b = append(b, ')') - } - b = append(b, ' ') - return b, nil -} - -func (q *Query) isSliceModelWithData() bool { - if !q.hasTableModel() { - return false - } - m, ok := q.tableModel.(*sliceTableModel) - return ok && m.sliceLen > 0 -} - -//------------------------------------------------------------------------------ - -type wherePKStructQuery struct { - q *Query -} - -var _ queryWithSepAppender = (*wherePKStructQuery)(nil) - -func (wherePKStructQuery) AppendSep(b []byte) []byte { - return append(b, " AND "...) -} - -func (q wherePKStructQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { - table := q.q.tableModel.Table() - value := q.q.tableModel.Value() - return appendColumnAndValue(fmter, b, value, table.Alias, table.PKs), nil -} - -func appendColumnAndValue( - fmter QueryFormatter, b []byte, v reflect.Value, alias types.Safe, fields []*Field, -) []byte { - isPlaceholder := isTemplateFormatter(fmter) - for i, f := range fields { - if i > 0 { - b = append(b, " AND "...) - } - b = append(b, alias...) - b = append(b, '.') - b = append(b, f.Column...) - b = append(b, " = "...) - if isPlaceholder { - b = append(b, '?') - } else { - b = f.AppendValue(b, v, 1) - } - } - return b -} - -//------------------------------------------------------------------------------ - -type wherePKSliceQuery struct { - q *Query -} - -var _ queryWithSepAppender = (*wherePKSliceQuery)(nil) - -func (wherePKSliceQuery) AppendSep(b []byte) []byte { - return append(b, " AND "...) -} - -func (q wherePKSliceQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { - table := q.q.tableModel.Table() - - for i, f := range table.PKs { - if i > 0 { - b = append(b, " AND "...) - } - b = append(b, table.Alias...) - b = append(b, '.') - b = append(b, f.Column...) - b = append(b, " = "...) - b = append(b, `"_data".`...) - b = append(b, f.Column...) - } - - return b, nil -} - -type joinPKSliceQuery struct { - q *Query -} - -var _ QueryAppender = (*joinPKSliceQuery)(nil) - -func (q joinPKSliceQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { - table := q.q.tableModel.Table() - slice := q.q.tableModel.Value() - - b = append(b, " JOIN (VALUES "...) - - sliceLen := slice.Len() - for i := 0; i < sliceLen; i++ { - if i > 0 { - b = append(b, ", "...) - } - - el := indirect(slice.Index(i)) - - b = append(b, '(') - for i, f := range table.PKs { - if i > 0 { - b = append(b, ", "...) - } - - b = f.AppendValue(b, el, 1) - - if f.UserSQLType != "" { - b = append(b, "::"...) - b = append(b, f.SQLType...) - } - } - - b = append(b, ", "...) - b = strconv.AppendInt(b, int64(i), 10) - - b = append(b, ')') - } - - b = append(b, `) AS "_data" (`...) - - for i, f := range table.PKs { - if i > 0 { - b = append(b, ", "...) - } - b = append(b, f.Column...) - } - - b = append(b, ", "...) - b = append(b, `"ordering"`...) - b = append(b, ") ON TRUE"...) - - return b, nil -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/relation.go b/vendor/github.com/go-pg/pg/v10/orm/relation.go deleted file mode 100644 index 28d915bcd..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/relation.go +++ /dev/null @@ -1,33 +0,0 @@ -package orm - -import ( - "fmt" - - "github.com/go-pg/pg/v10/types" -) - -const ( - InvalidRelation = iota - HasOneRelation - BelongsToRelation - HasManyRelation - Many2ManyRelation -) - -type Relation struct { - Type int - Field *Field - JoinTable *Table - BaseFKs []*Field - JoinFKs []*Field - Polymorphic *Field - - M2MTableName types.Safe - M2MTableAlias types.Safe - M2MBaseFKs []string - M2MJoinFKs []string -} - -func (r *Relation) String() string { - return fmt.Sprintf("relation=%s", r.Field.GoName) -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/result.go b/vendor/github.com/go-pg/pg/v10/orm/result.go deleted file mode 100644 index 9d82815ef..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/result.go +++ /dev/null @@ -1,14 +0,0 @@ -package orm - -// Result summarizes an executed SQL command. -type Result interface { - Model() Model - - // RowsAffected returns the number of rows affected by SELECT, INSERT, UPDATE, - // or DELETE queries. It returns -1 if query can't possibly affect any rows, - // e.g. in case of CREATE or SHOW queries. - RowsAffected() int - - // RowsReturned returns the number of rows returned by the query. - RowsReturned() int -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/select.go b/vendor/github.com/go-pg/pg/v10/orm/select.go deleted file mode 100644 index d3b38742d..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/select.go +++ /dev/null @@ -1,346 +0,0 @@ -package orm - -import ( - "bytes" - "fmt" - "strconv" - "strings" - - "github.com/go-pg/pg/v10/types" -) - -type SelectQuery struct { - q *Query - count string -} - -var ( - _ QueryAppender = (*SelectQuery)(nil) - _ QueryCommand = (*SelectQuery)(nil) -) - -func NewSelectQuery(q *Query) *SelectQuery { - return &SelectQuery{ - q: q, - } -} - -func (q *SelectQuery) String() string { - b, err := q.AppendQuery(defaultFmter, nil) - if err != nil { - panic(err) - } - return string(b) -} - -func (q *SelectQuery) Operation() QueryOp { - return SelectOp -} - -func (q *SelectQuery) Clone() QueryCommand { - return &SelectQuery{ - q: q.q.Clone(), - count: q.count, - } -} - -func (q *SelectQuery) Query() *Query { - return q.q -} - -func (q *SelectQuery) AppendTemplate(b []byte) ([]byte, error) { - return q.AppendQuery(dummyFormatter{}, b) -} - -func (q *SelectQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { //nolint:gocyclo - if q.q.stickyErr != nil { - return nil, q.q.stickyErr - } - - cteCount := q.count != "" && (len(q.q.group) > 0 || q.isDistinct()) - if cteCount { - b = append(b, `WITH "_count_wrapper" AS (`...) - } - - if len(q.q.with) > 0 { - b, err = q.q.appendWith(fmter, b) - if err != nil { - return nil, err - } - } - - if len(q.q.union) > 0 { - b = append(b, '(') - } - - b = append(b, "SELECT "...) - - if len(q.q.distinctOn) > 0 { - b = append(b, "DISTINCT ON ("...) - for i, app := range q.q.distinctOn { - if i > 0 { - b = append(b, ", "...) - } - b, err = app.AppendQuery(fmter, b) - } - b = append(b, ") "...) - } else if q.q.distinctOn != nil { - b = append(b, "DISTINCT "...) - } - - if q.count != "" && !cteCount { - b = append(b, q.count...) - } else { - b, err = q.appendColumns(fmter, b) - if err != nil { - return nil, err - } - } - - if q.q.hasTables() { - b = append(b, " FROM "...) - b, err = q.appendTables(fmter, b) - if err != nil { - return nil, err - } - } - - err = q.q.forEachHasOneJoin(func(j *join) error { - b = append(b, ' ') - b, err = j.appendHasOneJoin(fmter, b, q.q) - return err - }) - if err != nil { - return nil, err - } - - for _, j := range q.q.joins { - b, err = j.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - } - - if len(q.q.where) > 0 || q.q.isSoftDelete() { - b = append(b, " WHERE "...) - b, err = q.q.appendWhere(fmter, b) - if err != nil { - return nil, err - } - } - - if len(q.q.group) > 0 { - b = append(b, " GROUP BY "...) - for i, f := range q.q.group { - if i > 0 { - b = append(b, ", "...) - } - b, err = f.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - } - } - - if len(q.q.having) > 0 { - b = append(b, " HAVING "...) - for i, f := range q.q.having { - if i > 0 { - b = append(b, " AND "...) - } - b = append(b, '(') - b, err = f.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - b = append(b, ')') - } - } - - if q.count == "" { - if len(q.q.order) > 0 { - b = append(b, " ORDER BY "...) - for i, f := range q.q.order { - if i > 0 { - b = append(b, ", "...) - } - b, err = f.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - } - } - - if q.q.limit != 0 { - b = append(b, " LIMIT "...) - b = strconv.AppendInt(b, int64(q.q.limit), 10) - } - - if q.q.offset != 0 { - b = append(b, " OFFSET "...) - b = strconv.AppendInt(b, int64(q.q.offset), 10) - } - - if q.q.selFor != nil { - b = append(b, " FOR "...) - b, err = q.q.selFor.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - } - } else if cteCount { - b = append(b, `) SELECT `...) - b = append(b, q.count...) - b = append(b, ` FROM "_count_wrapper"`...) - } - - if len(q.q.union) > 0 { - b = append(b, ")"...) - - for _, u := range q.q.union { - b = append(b, u.expr...) - b = append(b, '(') - b, err = u.query.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - b = append(b, ")"...) - } - } - - return b, q.q.stickyErr -} - -func (q SelectQuery) appendColumns(fmter QueryFormatter, b []byte) (_ []byte, err error) { - start := len(b) - - switch { - case q.q.columns != nil: - b, err = q.q.appendColumns(fmter, b) - if err != nil { - return nil, err - } - case q.q.hasExplicitTableModel(): - table := q.q.tableModel.Table() - if len(table.Fields) > 10 && isTemplateFormatter(fmter) { - b = append(b, table.Alias...) - b = append(b, '.') - b = types.AppendString(b, fmt.Sprintf("%d columns", len(table.Fields)), 2) - } else { - b = appendColumns(b, table.Alias, table.Fields) - } - default: - b = append(b, '*') - } - - err = q.q.forEachHasOneJoin(func(j *join) error { - if len(b) != start { - b = append(b, ", "...) - start = len(b) - } - - b = j.appendHasOneColumns(b) - return nil - }) - if err != nil { - return nil, err - } - - b = bytes.TrimSuffix(b, []byte(", ")) - - return b, nil -} - -func (q *SelectQuery) isDistinct() bool { - if q.q.distinctOn != nil { - return true - } - for _, column := range q.q.columns { - column, ok := column.(*SafeQueryAppender) - if ok { - if strings.Contains(column.query, "DISTINCT") || - strings.Contains(column.query, "distinct") { - return true - } - } - } - return false -} - -func (q *SelectQuery) appendTables(fmter QueryFormatter, b []byte) (_ []byte, err error) { - tables := q.q.tables - - if q.q.modelHasTableName() { - table := q.q.tableModel.Table() - b = fmter.FormatQuery(b, string(table.SQLNameForSelects)) - if table.Alias != "" { - b = append(b, " AS "...) - b = append(b, table.Alias...) - } - - if len(tables) > 0 { - b = append(b, ", "...) - } - } else if len(tables) > 0 { - b, err = tables[0].AppendQuery(fmter, b) - if err != nil { - return nil, err - } - if q.q.modelHasTableAlias() { - b = append(b, " AS "...) - b = append(b, q.q.tableModel.Table().Alias...) - } - - tables = tables[1:] - if len(tables) > 0 { - b = append(b, ", "...) - } - } - - for i, f := range tables { - if i > 0 { - b = append(b, ", "...) - } - b, err = f.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - } - - return b, nil -} - -//------------------------------------------------------------------------------ - -type joinQuery struct { - join *SafeQueryAppender - on []*condAppender -} - -func (j *joinQuery) AppendOn(app *condAppender) { - j.on = append(j.on, app) -} - -func (j *joinQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { - b = append(b, ' ') - - b, err = j.join.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - - if len(j.on) > 0 { - b = append(b, " ON "...) - for i, on := range j.on { - if i > 0 { - b = on.AppendSep(b) - } - b, err = on.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - } - } - - return b, nil -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/table.go b/vendor/github.com/go-pg/pg/v10/orm/table.go deleted file mode 100644 index 8b57bbfc0..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/table.go +++ /dev/null @@ -1,1560 +0,0 @@ -package orm - -import ( - "database/sql" - "encoding/json" - "fmt" - "net" - "reflect" - "strconv" - "strings" - "sync" - "time" - - "github.com/jinzhu/inflection" - "github.com/vmihailenco/tagparser" - - "github.com/go-pg/pg/v10/internal" - "github.com/go-pg/pg/v10/internal/pool" - "github.com/go-pg/pg/v10/pgjson" - "github.com/go-pg/pg/v10/types" - "github.com/go-pg/zerochecker" -) - -const ( - beforeScanHookFlag = uint16(1) << iota - afterScanHookFlag - afterSelectHookFlag - beforeInsertHookFlag - afterInsertHookFlag - beforeUpdateHookFlag - afterUpdateHookFlag - beforeDeleteHookFlag - afterDeleteHookFlag - discardUnknownColumnsFlag -) - -var ( - timeType = reflect.TypeOf((*time.Time)(nil)).Elem() - nullTimeType = reflect.TypeOf((*types.NullTime)(nil)).Elem() - sqlNullTimeType = reflect.TypeOf((*sql.NullTime)(nil)).Elem() - ipType = reflect.TypeOf((*net.IP)(nil)).Elem() - ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() - scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() - nullBoolType = reflect.TypeOf((*sql.NullBool)(nil)).Elem() - nullFloatType = reflect.TypeOf((*sql.NullFloat64)(nil)).Elem() - nullIntType = reflect.TypeOf((*sql.NullInt64)(nil)).Elem() - nullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem() - jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() -) - -var tableNameInflector = inflection.Plural - -// SetTableNameInflector overrides the default func that pluralizes -// model name to get table name, e.g. my_article becomes my_articles. -func SetTableNameInflector(fn func(string) string) { - tableNameInflector = fn -} - -// Table represents a SQL table created from Go struct. -type Table struct { - Type reflect.Type - zeroStruct reflect.Value - - TypeName string - Alias types.Safe - ModelName string - - SQLName types.Safe - SQLNameForSelects types.Safe - - Tablespace types.Safe - - PartitionBy string - - allFields []*Field // read only - skippedFields []*Field - - Fields []*Field // PKs + DataFields - PKs []*Field - DataFields []*Field - fieldsMapMu sync.RWMutex - FieldsMap map[string]*Field - - Methods map[string]*Method - Relations map[string]*Relation - Unique map[string][]*Field - - SoftDeleteField *Field - SetSoftDeleteField func(fv reflect.Value) error - - flags uint16 -} - -func newTable(typ reflect.Type) *Table { - t := new(Table) - t.Type = typ - t.zeroStruct = reflect.New(t.Type).Elem() - t.TypeName = internal.ToExported(t.Type.Name()) - t.ModelName = internal.Underscore(t.Type.Name()) - tableName := tableNameInflector(t.ModelName) - t.setName(quoteIdent(tableName)) - t.Alias = quoteIdent(t.ModelName) - - typ = reflect.PtrTo(t.Type) - if typ.Implements(beforeScanHookType) { - t.setFlag(beforeScanHookFlag) - } - if typ.Implements(afterScanHookType) { - t.setFlag(afterScanHookFlag) - } - if typ.Implements(afterSelectHookType) { - t.setFlag(afterSelectHookFlag) - } - if typ.Implements(beforeInsertHookType) { - t.setFlag(beforeInsertHookFlag) - } - if typ.Implements(afterInsertHookType) { - t.setFlag(afterInsertHookFlag) - } - if typ.Implements(beforeUpdateHookType) { - t.setFlag(beforeUpdateHookFlag) - } - if typ.Implements(afterUpdateHookType) { - t.setFlag(afterUpdateHookFlag) - } - if typ.Implements(beforeDeleteHookType) { - t.setFlag(beforeDeleteHookFlag) - } - if typ.Implements(afterDeleteHookType) { - t.setFlag(afterDeleteHookFlag) - } - - return t -} - -func (t *Table) init1() { - t.initFields() - t.initMethods() -} - -func (t *Table) init2() { - t.initInlines() - t.initRelations() - t.skippedFields = nil -} - -func (t *Table) setName(name types.Safe) { - t.SQLName = name - t.SQLNameForSelects = name - if t.Alias == "" { - t.Alias = name - } -} - -func (t *Table) String() string { - return "model=" + t.TypeName -} - -func (t *Table) setFlag(flag uint16) { - t.flags |= flag -} - -func (t *Table) hasFlag(flag uint16) bool { - if t == nil { - return false - } - return t.flags&flag != 0 -} - -func (t *Table) checkPKs() error { - if len(t.PKs) == 0 { - return fmt.Errorf("pg: %s does not have primary keys", t) - } - return nil -} - -func (t *Table) mustSoftDelete() error { - if t.SoftDeleteField == nil { - return fmt.Errorf("pg: %s does not support soft deletes", t) - } - return nil -} - -func (t *Table) AddField(field *Field) { - t.Fields = append(t.Fields, field) - if field.hasFlag(PrimaryKeyFlag) { - t.PKs = append(t.PKs, field) - } else { - t.DataFields = append(t.DataFields, field) - } - t.FieldsMap[field.SQLName] = field -} - -func (t *Table) RemoveField(field *Field) { - t.Fields = removeField(t.Fields, field) - if field.hasFlag(PrimaryKeyFlag) { - t.PKs = removeField(t.PKs, field) - } else { - t.DataFields = removeField(t.DataFields, field) - } - delete(t.FieldsMap, field.SQLName) -} - -func removeField(fields []*Field, field *Field) []*Field { - for i, f := range fields { - if f == field { - fields = append(fields[:i], fields[i+1:]...) - } - } - return fields -} - -func (t *Table) getField(name string) *Field { - t.fieldsMapMu.RLock() - field := t.FieldsMap[name] - t.fieldsMapMu.RUnlock() - return field -} - -func (t *Table) HasField(name string) bool { - _, ok := t.FieldsMap[name] - return ok -} - -func (t *Table) GetField(name string) (*Field, error) { - field, ok := t.FieldsMap[name] - if !ok { - return nil, fmt.Errorf("pg: %s does not have column=%s", t, name) - } - return field, nil -} - -func (t *Table) AppendParam(b []byte, strct reflect.Value, name string) ([]byte, bool) { - field, ok := t.FieldsMap[name] - if ok { - b = field.AppendValue(b, strct, 1) - return b, true - } - - method, ok := t.Methods[name] - if ok { - b = method.AppendValue(b, strct.Addr(), 1) - return b, true - } - - return b, false -} - -func (t *Table) initFields() { - t.Fields = make([]*Field, 0, t.Type.NumField()) - t.FieldsMap = make(map[string]*Field, t.Type.NumField()) - t.addFields(t.Type, nil) -} - -func (t *Table) addFields(typ reflect.Type, baseIndex []int) { - for i := 0; i < typ.NumField(); i++ { - f := typ.Field(i) - - // Make a copy so slice is not shared between fields. - index := make([]int, len(baseIndex)) - copy(index, baseIndex) - - if f.Anonymous { - if f.Tag.Get("sql") == "-" || f.Tag.Get("pg") == "-" { - continue - } - - fieldType := indirectType(f.Type) - if fieldType.Kind() != reflect.Struct { - continue - } - t.addFields(fieldType, append(index, f.Index...)) - - pgTag := tagparser.Parse(f.Tag.Get("pg")) - if _, inherit := pgTag.Options["inherit"]; inherit { - embeddedTable := _tables.get(fieldType, true) - t.TypeName = embeddedTable.TypeName - t.SQLName = embeddedTable.SQLName - t.SQLNameForSelects = embeddedTable.SQLNameForSelects - t.Alias = embeddedTable.Alias - t.ModelName = embeddedTable.ModelName - } - - continue - } - - field := t.newField(f, index) - if field != nil { - t.AddField(field) - } - } -} - -//nolint -func (t *Table) newField(f reflect.StructField, index []int) *Field { - pgTag := tagparser.Parse(f.Tag.Get("pg")) - - switch f.Name { - case "tableName": - if len(index) > 0 { - return nil - } - - if isKnownTableOption(pgTag.Name) { - internal.Warn.Printf( - "%s.%s tag name %q is also an option name; is it a mistake?", - t.TypeName, f.Name, pgTag.Name, - ) - } - - for name := range pgTag.Options { - if !isKnownTableOption(name) { - internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name) - } - } - - if tableSpace, ok := pgTag.Options["tablespace"]; ok { - s, _ := tagparser.Unquote(tableSpace) - t.Tablespace = quoteIdent(s) - } - - partitionBy, ok := pgTag.Options["partition_by"] - if !ok { - partitionBy, ok = pgTag.Options["partitionBy"] - if ok { - internal.Deprecated.Printf("partitionBy is renamed to partition_by") - } - } - if ok { - s, _ := tagparser.Unquote(partitionBy) - t.PartitionBy = s - } - - if pgTag.Name == "_" { - t.setName("") - } else if pgTag.Name != "" { - s, _ := tagparser.Unquote(pgTag.Name) - t.setName(types.Safe(quoteTableName(s))) - } - - if s, ok := pgTag.Options["select"]; ok { - s, _ = tagparser.Unquote(s) - t.SQLNameForSelects = types.Safe(quoteTableName(s)) - } - - if v, ok := pgTag.Options["alias"]; ok { - v, _ = tagparser.Unquote(v) - t.Alias = quoteIdent(v) - } - - pgTag := tagparser.Parse(f.Tag.Get("pg")) - if _, ok := pgTag.Options["discard_unknown_columns"]; ok { - t.setFlag(discardUnknownColumnsFlag) - } - - return nil - } - - if f.PkgPath != "" { - return nil - } - - sqlName := internal.Underscore(f.Name) - - if pgTag.Name != sqlName && isKnownFieldOption(pgTag.Name) { - internal.Warn.Printf( - "%s.%s tag name %q is also an option name; is it a mistake?", - t.TypeName, f.Name, pgTag.Name, - ) - } - - for name := range pgTag.Options { - if !isKnownFieldOption(name) { - internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name) - } - } - - skip := pgTag.Name == "-" - if !skip && pgTag.Name != "" { - sqlName = pgTag.Name - } - - index = append(index, f.Index...) - if field := t.getField(sqlName); field != nil { - if indexEqual(field.Index, index) { - return field - } - t.RemoveField(field) - } - - field := &Field{ - Field: f, - Type: indirectType(f.Type), - - GoName: f.Name, - SQLName: sqlName, - Column: quoteIdent(sqlName), - - Index: index, - } - - if _, ok := pgTag.Options["notnull"]; ok { - field.setFlag(NotNullFlag) - } - if v, ok := pgTag.Options["unique"]; ok { - if v == "" { - field.setFlag(UniqueFlag) - } - // Split the value by comma, this will allow multiple names to be specified. - // We can use this to create multiple named unique constraints where a single column - // might be included in multiple constraints. - v, _ = tagparser.Unquote(v) - for _, uniqueName := range strings.Split(v, ",") { - if t.Unique == nil { - t.Unique = make(map[string][]*Field) - } - t.Unique[uniqueName] = append(t.Unique[uniqueName], field) - } - } - if v, ok := pgTag.Options["default"]; ok { - v, ok = tagparser.Unquote(v) - if ok { - field.Default = types.Safe(types.AppendString(nil, v, 1)) - } else { - field.Default = types.Safe(v) - } - } - - //nolint - if _, ok := pgTag.Options["pk"]; ok { - field.setFlag(PrimaryKeyFlag) - } else if strings.HasSuffix(field.SQLName, "_id") || - strings.HasSuffix(field.SQLName, "_uuid") { - field.setFlag(ForeignKeyFlag) - } else if strings.HasPrefix(field.SQLName, "fk_") { - field.setFlag(ForeignKeyFlag) - } else if len(t.PKs) == 0 && !pgTag.HasOption("nopk") { - switch field.SQLName { - case "id", "uuid", "pk_" + t.ModelName: - field.setFlag(PrimaryKeyFlag) - } - } - - if _, ok := pgTag.Options["use_zero"]; ok { - field.setFlag(UseZeroFlag) - } - if _, ok := pgTag.Options["array"]; ok { - field.setFlag(ArrayFlag) - } - - field.SQLType = fieldSQLType(field, pgTag) - if strings.HasSuffix(field.SQLType, "[]") { - field.setFlag(ArrayFlag) - } - - if v, ok := pgTag.Options["on_delete"]; ok { - field.OnDelete = v - } - - if v, ok := pgTag.Options["on_update"]; ok { - field.OnUpdate = v - } - - if _, ok := pgTag.Options["composite"]; ok { - field.append = compositeAppender(f.Type) - field.scan = compositeScanner(f.Type) - } else if _, ok := pgTag.Options["json_use_number"]; ok { - field.append = types.Appender(f.Type) - field.scan = scanJSONValue - } else if field.hasFlag(ArrayFlag) { - field.append = types.ArrayAppender(f.Type) - field.scan = types.ArrayScanner(f.Type) - } else if _, ok := pgTag.Options["hstore"]; ok { - field.append = types.HstoreAppender(f.Type) - field.scan = types.HstoreScanner(f.Type) - } else if field.SQLType == pgTypeBigint && field.Type.Kind() == reflect.Uint64 { - if f.Type.Kind() == reflect.Ptr { - field.append = appendUintPtrAsInt - } else { - field.append = appendUintAsInt - } - field.scan = types.Scanner(f.Type) - } else if _, ok := pgTag.Options["msgpack"]; ok { - field.append = msgpackAppender(f.Type) - field.scan = msgpackScanner(f.Type) - } else { - field.append = types.Appender(f.Type) - field.scan = types.Scanner(f.Type) - } - field.isZero = zerochecker.Checker(f.Type) - - if v, ok := pgTag.Options["alias"]; ok { - v, _ = tagparser.Unquote(v) - t.FieldsMap[v] = field - } - - t.allFields = append(t.allFields, field) - if skip { - t.skippedFields = append(t.skippedFields, field) - t.FieldsMap[field.SQLName] = field - return nil - } - - if _, ok := pgTag.Options["soft_delete"]; ok { - t.SetSoftDeleteField = setSoftDeleteFieldFunc(f.Type) - if t.SetSoftDeleteField == nil { - err := fmt.Errorf( - "pg: soft_delete is only supported for time.Time, pg.NullTime, sql.NullInt64, and int64 (or implement ValueScanner that scans time)") - panic(err) - } - t.SoftDeleteField = field - } - - return field -} - -func (t *Table) initMethods() { - t.Methods = make(map[string]*Method) - typ := reflect.PtrTo(t.Type) - for i := 0; i < typ.NumMethod(); i++ { - m := typ.Method(i) - if m.PkgPath != "" { - continue - } - if m.Type.NumIn() > 1 { - continue - } - if m.Type.NumOut() != 1 { - continue - } - - retType := m.Type.Out(0) - t.Methods[m.Name] = &Method{ - Index: m.Index, - - appender: types.Appender(retType), - } - } -} - -func (t *Table) initInlines() { - for _, f := range t.skippedFields { - if f.Type.Kind() == reflect.Struct { - t.inlineFields(f, nil) - } - } -} - -func (t *Table) initRelations() { - for i := 0; i < len(t.Fields); { - f := t.Fields[i] - if t.tryRelation(f) { - t.Fields = removeField(t.Fields, f) - t.DataFields = removeField(t.DataFields, f) - } else { - i++ - } - - if f.Type.Kind() == reflect.Struct { - t.inlineFields(f, nil) - } - } -} - -func (t *Table) tryRelation(field *Field) bool { - pgTag := tagparser.Parse(field.Field.Tag.Get("pg")) - - if rel, ok := pgTag.Options["rel"]; ok { - return t.tryRelationType(field, rel, pgTag) - } - if _, ok := pgTag.Options["many2many"]; ok { - return t.tryRelationType(field, "many2many", pgTag) - } - - if field.UserSQLType != "" || isScanner(field.Type) { - return false - } - - switch field.Type.Kind() { - case reflect.Slice: - return t.tryRelationSlice(field, pgTag) - case reflect.Struct: - return t.tryRelationStruct(field, pgTag) - } - return false -} - -func (t *Table) tryRelationType(field *Field, rel string, pgTag *tagparser.Tag) bool { - switch rel { - case "has-one": - return t.mustHasOneRelation(field, pgTag) - case "belongs-to": - return t.mustBelongsToRelation(field, pgTag) - case "has-many": - return t.mustHasManyRelation(field, pgTag) - case "many2many": - return t.mustM2MRelation(field, pgTag) - default: - panic(fmt.Errorf("pg: unknown relation=%s on field=%s", rel, field.GoName)) - } -} - -func (t *Table) mustHasOneRelation(field *Field, pgTag *tagparser.Tag) bool { - joinTable := _tables.get(field.Type, true) - if err := joinTable.checkPKs(); err != nil { - panic(err) - } - fkPrefix, fkOK := pgTag.Options["fk"] - - if fkOK && len(joinTable.PKs) == 1 { - fk := t.getField(fkPrefix) - if fk == nil { - panic(fmt.Errorf( - "pg: %s has-one %s: %s must have column %s "+ - "(use fk:custom_column tag on %s field to specify custom column)", - t.TypeName, field.GoName, t.TypeName, fkPrefix, field.GoName, - )) - } - - t.addRelation(&Relation{ - Type: HasOneRelation, - Field: field, - JoinTable: joinTable, - BaseFKs: []*Field{fk}, - JoinFKs: joinTable.PKs, - }) - return true - } - - if !fkOK { - fkPrefix = internal.Underscore(field.GoName) + "_" - } - fks := make([]*Field, 0, len(joinTable.PKs)) - - for _, joinPK := range joinTable.PKs { - fkName := fkPrefix + joinPK.SQLName - if fk := t.getField(fkName); fk != nil { - fks = append(fks, fk) - continue - } - - if fk := t.getField(joinPK.SQLName); fk != nil { - fks = append(fks, fk) - continue - } - - panic(fmt.Errorf( - "pg: %s has-one %s: %s must have column %s "+ - "(use fk:custom_column tag on %s field to specify custom column)", - t.TypeName, field.GoName, t.TypeName, fkName, field.GoName, - )) - } - - t.addRelation(&Relation{ - Type: HasOneRelation, - Field: field, - JoinTable: joinTable, - BaseFKs: fks, - JoinFKs: joinTable.PKs, - }) - return true -} - -func (t *Table) mustBelongsToRelation(field *Field, pgTag *tagparser.Tag) bool { - if err := t.checkPKs(); err != nil { - panic(err) - } - joinTable := _tables.get(field.Type, true) - fkPrefix, fkOK := pgTag.Options["join_fk"] - - if fkOK && len(t.PKs) == 1 { - fk := joinTable.getField(fkPrefix) - if fk == nil { - panic(fmt.Errorf( - "pg: %s belongs-to %s: %s must have column %s "+ - "(use join_fk:custom_column tag on %s field to specify custom column)", - field.GoName, t.TypeName, joinTable.TypeName, fkPrefix, field.GoName, - )) - } - - t.addRelation(&Relation{ - Type: BelongsToRelation, - Field: field, - JoinTable: joinTable, - BaseFKs: t.PKs, - JoinFKs: []*Field{fk}, - }) - return true - } - - if !fkOK { - fkPrefix = internal.Underscore(t.ModelName) + "_" - } - fks := make([]*Field, 0, len(t.PKs)) - - for _, pk := range t.PKs { - fkName := fkPrefix + pk.SQLName - if fk := joinTable.getField(fkName); fk != nil { - fks = append(fks, fk) - continue - } - - if fk := joinTable.getField(pk.SQLName); fk != nil { - fks = append(fks, fk) - continue - } - - panic(fmt.Errorf( - "pg: %s belongs-to %s: %s must have column %s "+ - "(use join_fk:custom_column tag on %s field to specify custom column)", - field.GoName, t.TypeName, joinTable.TypeName, fkName, field.GoName, - )) - } - - t.addRelation(&Relation{ - Type: BelongsToRelation, - Field: field, - JoinTable: joinTable, - BaseFKs: t.PKs, - JoinFKs: fks, - }) - return true -} - -func (t *Table) mustHasManyRelation(field *Field, pgTag *tagparser.Tag) bool { - if err := t.checkPKs(); err != nil { - panic(err) - } - if field.Type.Kind() != reflect.Slice { - panic(fmt.Errorf( - "pg: %s.%s has-many relation requires slice, got %q", - t.TypeName, field.GoName, field.Type.Kind(), - )) - } - - joinTable := _tables.get(indirectType(field.Type.Elem()), true) - fkPrefix, fkOK := pgTag.Options["join_fk"] - _, polymorphic := pgTag.Options["polymorphic"] - - if fkOK && !polymorphic && len(t.PKs) == 1 { - fk := joinTable.getField(fkPrefix) - if fk == nil { - panic(fmt.Errorf( - "pg: %s has-many %s: %s must have column %s "+ - "(use join_fk:custom_column tag on %s field to specify custom column)", - t.TypeName, field.GoName, joinTable.TypeName, fkPrefix, field.GoName, - )) - } - - t.addRelation(&Relation{ - Type: HasManyRelation, - Field: field, - JoinTable: joinTable, - BaseFKs: t.PKs, - JoinFKs: []*Field{fk}, - }) - return true - } - - if !fkOK { - fkPrefix = internal.Underscore(t.ModelName) + "_" - } - fks := make([]*Field, 0, len(t.PKs)) - - for _, pk := range t.PKs { - fkName := fkPrefix + pk.SQLName - if fk := joinTable.getField(fkName); fk != nil { - fks = append(fks, fk) - continue - } - - if fk := joinTable.getField(pk.SQLName); fk != nil { - fks = append(fks, fk) - continue - } - - panic(fmt.Errorf( - "pg: %s has-many %s: %s must have column %s "+ - "(use join_fk:custom_column tag on %s field to specify custom column)", - t.TypeName, field.GoName, joinTable.TypeName, fkName, field.GoName, - )) - } - - var typeField *Field - - if polymorphic { - typeFieldName := fkPrefix + "type" - typeField = joinTable.getField(typeFieldName) - if typeField == nil { - panic(fmt.Errorf( - "pg: %s has-many %s: %s must have polymorphic column %s", - t.TypeName, field.GoName, joinTable.TypeName, typeFieldName, - )) - } - } - - t.addRelation(&Relation{ - Type: HasManyRelation, - Field: field, - JoinTable: joinTable, - BaseFKs: t.PKs, - JoinFKs: fks, - Polymorphic: typeField, - }) - return true -} - -func (t *Table) mustM2MRelation(field *Field, pgTag *tagparser.Tag) bool { - if field.Type.Kind() != reflect.Slice { - panic(fmt.Errorf( - "pg: %s.%s many2many relation requires slice, got %q", - t.TypeName, field.GoName, field.Type.Kind(), - )) - } - joinTable := _tables.get(indirectType(field.Type.Elem()), true) - - if err := t.checkPKs(); err != nil { - panic(err) - } - if err := joinTable.checkPKs(); err != nil { - panic(err) - } - - m2mTableNameString, ok := pgTag.Options["many2many"] - if !ok { - panic(fmt.Errorf("pg: %s must have many2many tag option", field.GoName)) - } - m2mTableName := quoteTableName(m2mTableNameString) - - m2mTable := _tables.getByName(m2mTableName) - if m2mTable == nil { - panic(fmt.Errorf( - "pg: can't find %s table (use orm.RegisterTable to register the model)", - m2mTableName, - )) - } - - var baseFKs []string - var joinFKs []string - - { - fkPrefix, ok := pgTag.Options["fk"] - if !ok { - fkPrefix = internal.Underscore(t.ModelName) + "_" - } - - if ok && len(t.PKs) == 1 { - if m2mTable.getField(fkPrefix) == nil { - panic(fmt.Errorf( - "pg: %s many2many %s: %s must have column %s "+ - "(use fk:custom_column tag on %s field to specify custom column)", - t.TypeName, field.GoName, m2mTable.TypeName, fkPrefix, field.GoName, - )) - } - baseFKs = []string{fkPrefix} - } else { - for _, pk := range t.PKs { - fkName := fkPrefix + pk.SQLName - if m2mTable.getField(fkName) != nil { - baseFKs = append(baseFKs, fkName) - continue - } - - if m2mTable.getField(pk.SQLName) != nil { - baseFKs = append(baseFKs, pk.SQLName) - continue - } - - panic(fmt.Errorf( - "pg: %s many2many %s: %s must have column %s "+ - "(use fk:custom_column tag on %s field to specify custom column)", - t.TypeName, field.GoName, m2mTable.TypeName, fkName, field.GoName, - )) - } - } - } - - { - joinFKPrefix, ok := pgTag.Options["join_fk"] - if !ok { - joinFKPrefix = internal.Underscore(joinTable.ModelName) + "_" - } - - if ok && len(joinTable.PKs) == 1 { - if m2mTable.getField(joinFKPrefix) == nil { - panic(fmt.Errorf( - "pg: %s many2many %s: %s must have column %s "+ - "(use join_fk:custom_column tag on %s field to specify custom column)", - joinTable.TypeName, field.GoName, m2mTable.TypeName, joinFKPrefix, field.GoName, - )) - } - joinFKs = []string{joinFKPrefix} - } else { - for _, joinPK := range joinTable.PKs { - fkName := joinFKPrefix + joinPK.SQLName - if m2mTable.getField(fkName) != nil { - joinFKs = append(joinFKs, fkName) - continue - } - - if m2mTable.getField(joinPK.SQLName) != nil { - joinFKs = append(joinFKs, joinPK.SQLName) - continue - } - - panic(fmt.Errorf( - "pg: %s many2many %s: %s must have column %s "+ - "(use join_fk:custom_column tag on %s field to specify custom column)", - t.TypeName, field.GoName, m2mTable.TypeName, fkName, field.GoName, - )) - } - } - } - - t.addRelation(&Relation{ - Type: Many2ManyRelation, - Field: field, - JoinTable: joinTable, - M2MTableName: m2mTableName, - M2MTableAlias: m2mTable.Alias, - M2MBaseFKs: baseFKs, - M2MJoinFKs: joinFKs, - }) - return true -} - -//nolint -func (t *Table) tryRelationSlice(field *Field, pgTag *tagparser.Tag) bool { - if t.tryM2MRelation(field, pgTag) { - internal.Deprecated.Printf( - `add pg:"rel:many2many" to %s.%s field tag`, t.TypeName, field.GoName) - return true - } - if t.tryHasManyRelation(field, pgTag) { - internal.Deprecated.Printf( - `add pg:"rel:has-many" to %s.%s field tag`, t.TypeName, field.GoName) - return true - } - return false -} - -func (t *Table) tryM2MRelation(field *Field, pgTag *tagparser.Tag) bool { - elemType := indirectType(field.Type.Elem()) - if elemType.Kind() != reflect.Struct { - return false - } - - joinTable := _tables.get(elemType, true) - - fk, fkOK := pgTag.Options["fk"] - if fkOK { - if fk == "-" { - return false - } - fk = tryUnderscorePrefix(fk) - } - - m2mTableName := pgTag.Options["many2many"] - if m2mTableName == "" { - return false - } - - m2mTable := _tables.getByName(quoteIdent(m2mTableName)) - - var m2mTableAlias types.Safe - if m2mTable != nil { - m2mTableAlias = m2mTable.Alias - } else if ind := strings.IndexByte(m2mTableName, '.'); ind >= 0 { - m2mTableAlias = quoteIdent(m2mTableName[ind+1:]) - } else { - m2mTableAlias = quoteIdent(m2mTableName) - } - - var fks []string - if !fkOK { - fk = t.ModelName + "_" - } - if m2mTable != nil { - keys := foreignKeys(t, m2mTable, fk, fkOK) - if len(keys) == 0 { - return false - } - for _, fk := range keys { - fks = append(fks, fk.SQLName) - } - } else { - if fkOK && len(t.PKs) == 1 { - fks = append(fks, fk) - } else { - for _, pk := range t.PKs { - fks = append(fks, fk+pk.SQLName) - } - } - } - - joinFK, joinFKOk := pgTag.Options["join_fk"] - if !joinFKOk { - joinFK, joinFKOk = pgTag.Options["joinFK"] - if joinFKOk { - internal.Deprecated.Printf("joinFK is renamed to join_fk") - } - } - if joinFKOk { - joinFK = tryUnderscorePrefix(joinFK) - } else { - joinFK = joinTable.ModelName + "_" - } - - var joinFKs []string - if m2mTable != nil { - keys := foreignKeys(joinTable, m2mTable, joinFK, joinFKOk) - if len(keys) == 0 { - return false - } - for _, fk := range keys { - joinFKs = append(joinFKs, fk.SQLName) - } - } else { - if joinFKOk && len(joinTable.PKs) == 1 { - joinFKs = append(joinFKs, joinFK) - } else { - for _, pk := range joinTable.PKs { - joinFKs = append(joinFKs, joinFK+pk.SQLName) - } - } - } - - t.addRelation(&Relation{ - Type: Many2ManyRelation, - Field: field, - JoinTable: joinTable, - M2MTableName: quoteIdent(m2mTableName), - M2MTableAlias: m2mTableAlias, - M2MBaseFKs: fks, - M2MJoinFKs: joinFKs, - }) - return true -} - -func (t *Table) tryHasManyRelation(field *Field, pgTag *tagparser.Tag) bool { - elemType := indirectType(field.Type.Elem()) - if elemType.Kind() != reflect.Struct { - return false - } - - joinTable := _tables.get(elemType, true) - - fk, fkOK := pgTag.Options["fk"] - if fkOK { - if fk == "-" { - return false - } - fk = tryUnderscorePrefix(fk) - } - - s, polymorphic := pgTag.Options["polymorphic"] - var typeField *Field - if polymorphic { - fk = tryUnderscorePrefix(s) - - typeField = joinTable.getField(fk + "type") - if typeField == nil { - return false - } - } else if !fkOK { - fk = t.ModelName + "_" - } - - fks := foreignKeys(t, joinTable, fk, fkOK || polymorphic) - if len(fks) == 0 { - return false - } - - var fkValues []*Field - fkValue, ok := pgTag.Options["fk_value"] - if ok { - if len(fks) > 1 { - panic(fmt.Errorf("got fk_value, but there are %d fks", len(fks))) - } - - f := t.getField(fkValue) - if f == nil { - panic(fmt.Errorf("fk_value=%q not found in %s", fkValue, t)) - } - fkValues = append(fkValues, f) - } else { - fkValues = t.PKs - } - - if len(fks) != len(fkValues) { - panic("len(fks) != len(fkValues)") - } - - if len(fks) > 0 { - t.addRelation(&Relation{ - Type: HasManyRelation, - Field: field, - JoinTable: joinTable, - BaseFKs: fkValues, - JoinFKs: fks, - Polymorphic: typeField, - }) - return true - } - - return false -} - -func (t *Table) tryRelationStruct(field *Field, pgTag *tagparser.Tag) bool { - joinTable := _tables.get(field.Type, true) - - if len(joinTable.allFields) == 0 { - return false - } - - if t.tryHasOne(joinTable, field, pgTag) { - internal.Deprecated.Printf( - `add pg:"rel:has-one" to %s.%s field tag`, t.TypeName, field.GoName) - t.inlineFields(field, nil) - return true - } - - if t.tryBelongsToOne(joinTable, field, pgTag) { - internal.Deprecated.Printf( - `add pg:"rel:belongs-to" to %s.%s field tag`, t.TypeName, field.GoName) - t.inlineFields(field, nil) - return true - } - - t.inlineFields(field, nil) - return false -} - -func (t *Table) inlineFields(strct *Field, path map[reflect.Type]struct{}) { - if path == nil { - path = map[reflect.Type]struct{}{ - t.Type: {}, - } - } - - if _, ok := path[strct.Type]; ok { - return - } - path[strct.Type] = struct{}{} - - joinTable := _tables.get(strct.Type, true) - for _, f := range joinTable.allFields { - f = f.Clone() - f.GoName = strct.GoName + "_" + f.GoName - f.SQLName = strct.SQLName + "__" + f.SQLName - f.Column = quoteIdent(f.SQLName) - f.Index = appendNew(strct.Index, f.Index...) - - t.fieldsMapMu.Lock() - if _, ok := t.FieldsMap[f.SQLName]; !ok { - t.FieldsMap[f.SQLName] = f - } - t.fieldsMapMu.Unlock() - - if f.Type.Kind() != reflect.Struct { - continue - } - - if _, ok := path[f.Type]; !ok { - t.inlineFields(f, path) - } - } -} - -func appendNew(dst []int, src ...int) []int { - cp := make([]int, len(dst)+len(src)) - copy(cp, dst) - copy(cp[len(dst):], src) - return cp -} - -func isScanner(typ reflect.Type) bool { - return typ.Implements(scannerType) || reflect.PtrTo(typ).Implements(scannerType) -} - -func fieldSQLType(field *Field, pgTag *tagparser.Tag) string { - if typ, ok := pgTag.Options["type"]; ok { - typ, _ = tagparser.Unquote(typ) - field.UserSQLType = typ - typ = normalizeSQLType(typ) - return typ - } - - if typ, ok := pgTag.Options["composite"]; ok { - typ, _ = tagparser.Unquote(typ) - return typ - } - - if _, ok := pgTag.Options["hstore"]; ok { - return "hstore" - } else if _, ok := pgTag.Options["hstore"]; ok { - return "hstore" - } - - if field.hasFlag(ArrayFlag) { - switch field.Type.Kind() { - case reflect.Slice, reflect.Array: - sqlType := sqlType(field.Type.Elem()) - return sqlType + "[]" - } - } - - sqlType := sqlType(field.Type) - return sqlType -} - -func sqlType(typ reflect.Type) string { - switch typ { - case timeType, nullTimeType, sqlNullTimeType: - return pgTypeTimestampTz - case ipType: - return pgTypeInet - case ipNetType: - return pgTypeCidr - case nullBoolType: - return pgTypeBoolean - case nullFloatType: - return pgTypeDoublePrecision - case nullIntType: - return pgTypeBigint - case nullStringType: - return pgTypeText - case jsonRawMessageType: - return pgTypeJSONB - } - - switch typ.Kind() { - case reflect.Int8, reflect.Uint8, reflect.Int16: - return pgTypeSmallint - case reflect.Uint16, reflect.Int32: - return pgTypeInteger - case reflect.Uint32, reflect.Int64, reflect.Int: - return pgTypeBigint - case reflect.Uint, reflect.Uint64: - // Unsigned bigint is not supported - use bigint. - return pgTypeBigint - case reflect.Float32: - return pgTypeReal - case reflect.Float64: - return pgTypeDoublePrecision - case reflect.Bool: - return pgTypeBoolean - case reflect.String: - return pgTypeText - case reflect.Map, reflect.Struct: - return pgTypeJSONB - case reflect.Array, reflect.Slice: - if typ.Elem().Kind() == reflect.Uint8 { - return pgTypeBytea - } - return pgTypeJSONB - default: - return typ.Kind().String() - } -} - -func normalizeSQLType(s string) string { - switch s { - case "int2": - return pgTypeSmallint - case "int4", "int", "serial": - return pgTypeInteger - case "int8", pgTypeBigserial: - return pgTypeBigint - case "float4": - return pgTypeReal - case "float8": - return pgTypeDoublePrecision - } - return s -} - -func sqlTypeEqual(a, b string) bool { - return a == b -} - -func (t *Table) tryHasOne(joinTable *Table, field *Field, pgTag *tagparser.Tag) bool { - fk, fkOK := pgTag.Options["fk"] - if fkOK { - if fk == "-" { - return false - } - fk = tryUnderscorePrefix(fk) - } else { - fk = internal.Underscore(field.GoName) + "_" - } - - fks := foreignKeys(joinTable, t, fk, fkOK) - if len(fks) > 0 { - t.addRelation(&Relation{ - Type: HasOneRelation, - Field: field, - JoinTable: joinTable, - BaseFKs: fks, - JoinFKs: joinTable.PKs, - }) - return true - } - return false -} - -func (t *Table) tryBelongsToOne(joinTable *Table, field *Field, pgTag *tagparser.Tag) bool { - fk, fkOK := pgTag.Options["fk"] - if fkOK { - if fk == "-" { - return false - } - fk = tryUnderscorePrefix(fk) - } else { - fk = internal.Underscore(t.TypeName) + "_" - } - - fks := foreignKeys(t, joinTable, fk, fkOK) - if len(fks) > 0 { - t.addRelation(&Relation{ - Type: BelongsToRelation, - Field: field, - JoinTable: joinTable, - BaseFKs: t.PKs, - JoinFKs: fks, - }) - return true - } - return false -} - -func (t *Table) addRelation(rel *Relation) { - if t.Relations == nil { - t.Relations = make(map[string]*Relation) - } - _, ok := t.Relations[rel.Field.GoName] - if ok { - panic(fmt.Errorf("%s already has %s", t, rel)) - } - t.Relations[rel.Field.GoName] = rel -} - -func foreignKeys(base, join *Table, fk string, tryFK bool) []*Field { - var fks []*Field - - for _, pk := range base.PKs { - fkName := fk + pk.SQLName - f := join.getField(fkName) - if f != nil && sqlTypeEqual(pk.SQLType, f.SQLType) { - fks = append(fks, f) - continue - } - - if strings.IndexByte(pk.SQLName, '_') == -1 { - continue - } - - f = join.getField(pk.SQLName) - if f != nil && sqlTypeEqual(pk.SQLType, f.SQLType) { - fks = append(fks, f) - continue - } - } - if len(fks) > 0 && len(fks) == len(base.PKs) { - return fks - } - - fks = nil - for _, pk := range base.PKs { - if !strings.HasPrefix(pk.SQLName, "pk_") { - continue - } - fkName := "fk_" + pk.SQLName[3:] - f := join.getField(fkName) - if f != nil && sqlTypeEqual(pk.SQLType, f.SQLType) { - fks = append(fks, f) - } - } - if len(fks) > 0 && len(fks) == len(base.PKs) { - return fks - } - - if fk == "" || len(base.PKs) != 1 { - return nil - } - - if tryFK { - f := join.getField(fk) - if f != nil && sqlTypeEqual(base.PKs[0].SQLType, f.SQLType) { - return []*Field{f} - } - } - - for _, suffix := range []string{"id", "uuid"} { - f := join.getField(fk + suffix) - if f != nil && sqlTypeEqual(base.PKs[0].SQLType, f.SQLType) { - return []*Field{f} - } - } - - return nil -} - -func scanJSONValue(v reflect.Value, rd types.Reader, n int) error { - // Zero value so it works with SelectOrInsert. - // TODO: better handle slices - v.Set(reflect.New(v.Type()).Elem()) - - if n == -1 { - return nil - } - - dec := pgjson.NewDecoder(rd) - dec.UseNumber() - return dec.Decode(v.Addr().Interface()) -} - -func appendUintAsInt(b []byte, v reflect.Value, _ int) []byte { - return strconv.AppendInt(b, int64(v.Uint()), 10) -} - -func appendUintPtrAsInt(b []byte, v reflect.Value, _ int) []byte { - return strconv.AppendInt(b, int64(v.Elem().Uint()), 10) -} - -func tryUnderscorePrefix(s string) string { - if s == "" { - return s - } - if c := s[0]; internal.IsUpper(c) { - return internal.Underscore(s) + "_" - } - return s -} - -func quoteTableName(s string) types.Safe { - // Don't quote if table name contains placeholder (?) or parentheses. - if strings.IndexByte(s, '?') >= 0 || - strings.IndexByte(s, '(') >= 0 && strings.IndexByte(s, ')') >= 0 { - return types.Safe(s) - } - return quoteIdent(s) -} - -func quoteIdent(s string) types.Safe { - return types.Safe(types.AppendIdent(nil, s, 1)) -} - -func setSoftDeleteFieldFunc(typ reflect.Type) func(fv reflect.Value) error { - switch typ { - case timeType: - return func(fv reflect.Value) error { - ptr := fv.Addr().Interface().(*time.Time) - *ptr = time.Now() - return nil - } - case nullTimeType: - return func(fv reflect.Value) error { - ptr := fv.Addr().Interface().(*types.NullTime) - *ptr = types.NullTime{Time: time.Now()} - return nil - } - case nullIntType: - return func(fv reflect.Value) error { - ptr := fv.Addr().Interface().(*sql.NullInt64) - *ptr = sql.NullInt64{Int64: time.Now().UnixNano()} - return nil - } - } - - switch typ.Kind() { - case reflect.Int64: - return func(fv reflect.Value) error { - ptr := fv.Addr().Interface().(*int64) - *ptr = time.Now().UnixNano() - return nil - } - case reflect.Ptr: - break - default: - return setSoftDeleteFallbackFunc(typ) - } - - originalType := typ - typ = typ.Elem() - - switch typ { //nolint:gocritic - case timeType: - return func(fv reflect.Value) error { - now := time.Now() - fv.Set(reflect.ValueOf(&now)) - return nil - } - } - - switch typ.Kind() { //nolint:gocritic - case reflect.Int64: - return func(fv reflect.Value) error { - utime := time.Now().UnixNano() - fv.Set(reflect.ValueOf(&utime)) - return nil - } - } - - return setSoftDeleteFallbackFunc(originalType) -} - -func setSoftDeleteFallbackFunc(typ reflect.Type) func(fv reflect.Value) error { - scanner := types.Scanner(typ) - if scanner == nil { - return nil - } - - return func(fv reflect.Value) error { - var flags int - b := types.AppendTime(nil, time.Now(), flags) - return scanner(fv, pool.NewBytesReader(b), len(b)) - } -} - -func isKnownTableOption(name string) bool { - switch name { - case "alias", - "select", - "tablespace", - "partition_by", - "discard_unknown_columns": - return true - } - return false -} - -func isKnownFieldOption(name string) bool { - switch name { - case "alias", - "type", - "array", - "hstore", - "composite", - "json_use_number", - "msgpack", - "notnull", - "use_zero", - "default", - "unique", - "soft_delete", - "on_delete", - "on_update", - - "pk", - "nopk", - "rel", - "fk", - "join_fk", - "many2many", - "polymorphic": - return true - } - return false -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/table_create.go b/vendor/github.com/go-pg/pg/v10/orm/table_create.go deleted file mode 100644 index 384c729de..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/table_create.go +++ /dev/null @@ -1,248 +0,0 @@ -package orm - -import ( - "sort" - "strconv" - - "github.com/go-pg/pg/v10/types" -) - -type CreateTableOptions struct { - Varchar int // replaces PostgreSQL data type `text` with `varchar(n)` - Temp bool - IfNotExists bool - - // FKConstraints causes CreateTable to create foreign key constraints - // for has one relations. ON DELETE hook can be added using tag - // `pg:"on_delete:RESTRICT"` on foreign key field. ON UPDATE hook can be added using tag - // `pg:"on_update:CASCADE"` - FKConstraints bool -} - -type CreateTableQuery struct { - q *Query - opt *CreateTableOptions -} - -var ( - _ QueryAppender = (*CreateTableQuery)(nil) - _ QueryCommand = (*CreateTableQuery)(nil) -) - -func NewCreateTableQuery(q *Query, opt *CreateTableOptions) *CreateTableQuery { - return &CreateTableQuery{ - q: q, - opt: opt, - } -} - -func (q *CreateTableQuery) String() string { - b, err := q.AppendQuery(defaultFmter, nil) - if err != nil { - panic(err) - } - return string(b) -} - -func (q *CreateTableQuery) Operation() QueryOp { - return CreateTableOp -} - -func (q *CreateTableQuery) Clone() QueryCommand { - return &CreateTableQuery{ - q: q.q.Clone(), - opt: q.opt, - } -} - -func (q *CreateTableQuery) Query() *Query { - return q.q -} - -func (q *CreateTableQuery) AppendTemplate(b []byte) ([]byte, error) { - return q.AppendQuery(dummyFormatter{}, b) -} - -func (q *CreateTableQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { - if q.q.stickyErr != nil { - return nil, q.q.stickyErr - } - if q.q.tableModel == nil { - return nil, errModelNil - } - - table := q.q.tableModel.Table() - - b = append(b, "CREATE "...) - if q.opt != nil && q.opt.Temp { - b = append(b, "TEMP "...) - } - b = append(b, "TABLE "...) - if q.opt != nil && q.opt.IfNotExists { - b = append(b, "IF NOT EXISTS "...) - } - b, err = q.q.appendFirstTable(fmter, b) - if err != nil { - return nil, err - } - b = append(b, " ("...) - - for i, field := range table.Fields { - if i > 0 { - b = append(b, ", "...) - } - - b = append(b, field.Column...) - b = append(b, " "...) - b = q.appendSQLType(b, field) - if field.hasFlag(NotNullFlag) { - b = append(b, " NOT NULL"...) - } - if field.hasFlag(UniqueFlag) { - b = append(b, " UNIQUE"...) - } - if field.Default != "" { - b = append(b, " DEFAULT "...) - b = append(b, field.Default...) - } - } - - b = appendPKConstraint(b, table.PKs) - b = appendUniqueConstraints(b, table) - - if q.opt != nil && q.opt.FKConstraints { - for _, rel := range table.Relations { - b = q.appendFKConstraint(fmter, b, rel) - } - } - - b = append(b, ")"...) - - if table.PartitionBy != "" { - b = append(b, " PARTITION BY "...) - b = append(b, table.PartitionBy...) - } - - if table.Tablespace != "" { - b = q.appendTablespace(b, table.Tablespace) - } - - return b, q.q.stickyErr -} - -func (q *CreateTableQuery) appendSQLType(b []byte, field *Field) []byte { - if field.UserSQLType != "" { - return append(b, field.UserSQLType...) - } - if q.opt != nil && q.opt.Varchar > 0 && - field.SQLType == "text" { - b = append(b, "varchar("...) - b = strconv.AppendInt(b, int64(q.opt.Varchar), 10) - b = append(b, ")"...) - return b - } - if field.hasFlag(PrimaryKeyFlag) { - return append(b, pkSQLType(field.SQLType)...) - } - return append(b, field.SQLType...) -} - -func pkSQLType(s string) string { - switch s { - case pgTypeSmallint: - return pgTypeSmallserial - case pgTypeInteger: - return pgTypeSerial - case pgTypeBigint: - return pgTypeBigserial - } - return s -} - -func appendPKConstraint(b []byte, pks []*Field) []byte { - if len(pks) == 0 { - return b - } - - b = append(b, ", PRIMARY KEY ("...) - b = appendColumns(b, "", pks) - b = append(b, ")"...) - return b -} - -func appendUniqueConstraints(b []byte, table *Table) []byte { - keys := make([]string, 0, len(table.Unique)) - for key := range table.Unique { - keys = append(keys, key) - } - sort.Strings(keys) - - for _, key := range keys { - b = appendUnique(b, table.Unique[key]) - } - - return b -} - -func appendUnique(b []byte, fields []*Field) []byte { - b = append(b, ", UNIQUE ("...) - b = appendColumns(b, "", fields) - b = append(b, ")"...) - return b -} - -func (q *CreateTableQuery) appendFKConstraint(fmter QueryFormatter, b []byte, rel *Relation) []byte { - if rel.Type != HasOneRelation { - return b - } - - b = append(b, ", FOREIGN KEY ("...) - b = appendColumns(b, "", rel.BaseFKs) - b = append(b, ")"...) - - b = append(b, " REFERENCES "...) - b = fmter.FormatQuery(b, string(rel.JoinTable.SQLName)) - b = append(b, " ("...) - b = appendColumns(b, "", rel.JoinFKs) - b = append(b, ")"...) - - if s := onDelete(rel.BaseFKs); s != "" { - b = append(b, " ON DELETE "...) - b = append(b, s...) - } - - if s := onUpdate(rel.BaseFKs); s != "" { - b = append(b, " ON UPDATE "...) - b = append(b, s...) - } - - return b -} - -func (q *CreateTableQuery) appendTablespace(b []byte, tableSpace types.Safe) []byte { - b = append(b, " TABLESPACE "...) - b = append(b, tableSpace...) - return b -} - -func onDelete(fks []*Field) string { - var onDelete string - for _, f := range fks { - if f.OnDelete != "" { - onDelete = f.OnDelete - break - } - } - return onDelete -} - -func onUpdate(fks []*Field) string { - var onUpdate string - for _, f := range fks { - if f.OnUpdate != "" { - onUpdate = f.OnUpdate - break - } - } - return onUpdate -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/table_drop.go b/vendor/github.com/go-pg/pg/v10/orm/table_drop.go deleted file mode 100644 index 599ac3952..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/table_drop.go +++ /dev/null @@ -1,73 +0,0 @@ -package orm - -type DropTableOptions struct { - IfExists bool - Cascade bool -} - -type DropTableQuery struct { - q *Query - opt *DropTableOptions -} - -var ( - _ QueryAppender = (*DropTableQuery)(nil) - _ QueryCommand = (*DropTableQuery)(nil) -) - -func NewDropTableQuery(q *Query, opt *DropTableOptions) *DropTableQuery { - return &DropTableQuery{ - q: q, - opt: opt, - } -} - -func (q *DropTableQuery) String() string { - b, err := q.AppendQuery(defaultFmter, nil) - if err != nil { - panic(err) - } - return string(b) -} - -func (q *DropTableQuery) Operation() QueryOp { - return DropTableOp -} - -func (q *DropTableQuery) Clone() QueryCommand { - return &DropTableQuery{ - q: q.q.Clone(), - opt: q.opt, - } -} - -func (q *DropTableQuery) Query() *Query { - return q.q -} - -func (q *DropTableQuery) AppendTemplate(b []byte) ([]byte, error) { - return q.AppendQuery(dummyFormatter{}, b) -} - -func (q *DropTableQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { - if q.q.stickyErr != nil { - return nil, q.q.stickyErr - } - if q.q.tableModel == nil { - return nil, errModelNil - } - - b = append(b, "DROP TABLE "...) - if q.opt != nil && q.opt.IfExists { - b = append(b, "IF EXISTS "...) - } - b, err = q.q.appendFirstTable(fmter, b) - if err != nil { - return nil, err - } - if q.opt != nil && q.opt.Cascade { - b = append(b, " CASCADE"...) - } - - return b, q.q.stickyErr -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/table_params.go b/vendor/github.com/go-pg/pg/v10/orm/table_params.go deleted file mode 100644 index 46d8e064a..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/table_params.go +++ /dev/null @@ -1,29 +0,0 @@ -package orm - -import "reflect" - -type tableParams struct { - table *Table - strct reflect.Value -} - -func newTableParams(strct interface{}) (*tableParams, bool) { - v := reflect.ValueOf(strct) - if !v.IsValid() { - return nil, false - } - - v = reflect.Indirect(v) - if v.Kind() != reflect.Struct { - return nil, false - } - - return &tableParams{ - table: GetTable(v.Type()), - strct: v, - }, true -} - -func (m *tableParams) AppendParam(fmter QueryFormatter, b []byte, name string) ([]byte, bool) { - return m.table.AppendParam(b, m.strct, name) -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/types.go b/vendor/github.com/go-pg/pg/v10/orm/types.go deleted file mode 100644 index c8e9ec375..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/types.go +++ /dev/null @@ -1,48 +0,0 @@ -package orm - -//nolint -const ( - // Date / Time - pgTypeTimestamp = "timestamp" // Timestamp without a time zone - pgTypeTimestampTz = "timestamptz" // Timestamp with a time zone - pgTypeDate = "date" // Date - pgTypeTime = "time" // Time without a time zone - pgTypeTimeTz = "time with time zone" // Time with a time zone - pgTypeInterval = "interval" // Time Interval - - // Network Addresses - pgTypeInet = "inet" // IPv4 or IPv6 hosts and networks - pgTypeCidr = "cidr" // IPv4 or IPv6 networks - pgTypeMacaddr = "macaddr" // MAC addresses - - // Boolean - pgTypeBoolean = "boolean" - - // Numeric Types - - // Floating Point Types - pgTypeReal = "real" // 4 byte floating point (6 digit precision) - pgTypeDoublePrecision = "double precision" // 8 byte floating point (15 digit precision) - - // Integer Types - pgTypeSmallint = "smallint" // 2 byte integer - pgTypeInteger = "integer" // 4 byte integer - pgTypeBigint = "bigint" // 8 byte integer - - // Serial Types - pgTypeSmallserial = "smallserial" // 2 byte autoincrementing integer - pgTypeSerial = "serial" // 4 byte autoincrementing integer - pgTypeBigserial = "bigserial" // 8 byte autoincrementing integer - - // Character Types - pgTypeVarchar = "varchar" // variable length string with limit - pgTypeChar = "char" // fixed length string (blank padded) - pgTypeText = "text" // variable length string without limit - - // JSON Types - pgTypeJSON = "json" // text representation of json data - pgTypeJSONB = "jsonb" // binary representation of json data - - // Binary Data Types - pgTypeBytea = "bytea" // binary string -) diff --git a/vendor/github.com/go-pg/pg/v10/orm/update.go b/vendor/github.com/go-pg/pg/v10/orm/update.go deleted file mode 100644 index ce6396fd3..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/update.go +++ /dev/null @@ -1,378 +0,0 @@ -package orm - -import ( - "fmt" - "reflect" - "sort" - - "github.com/go-pg/pg/v10/types" -) - -type UpdateQuery struct { - q *Query - omitZero bool - placeholder bool -} - -var ( - _ QueryAppender = (*UpdateQuery)(nil) - _ QueryCommand = (*UpdateQuery)(nil) -) - -func NewUpdateQuery(q *Query, omitZero bool) *UpdateQuery { - return &UpdateQuery{ - q: q, - omitZero: omitZero, - } -} - -func (q *UpdateQuery) String() string { - b, err := q.AppendQuery(defaultFmter, nil) - if err != nil { - panic(err) - } - return string(b) -} - -func (q *UpdateQuery) Operation() QueryOp { - return UpdateOp -} - -func (q *UpdateQuery) Clone() QueryCommand { - return &UpdateQuery{ - q: q.q.Clone(), - omitZero: q.omitZero, - placeholder: q.placeholder, - } -} - -func (q *UpdateQuery) Query() *Query { - return q.q -} - -func (q *UpdateQuery) AppendTemplate(b []byte) ([]byte, error) { - cp := q.Clone().(*UpdateQuery) - cp.placeholder = true - return cp.AppendQuery(dummyFormatter{}, b) -} - -func (q *UpdateQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { - if q.q.stickyErr != nil { - return nil, q.q.stickyErr - } - - if len(q.q.with) > 0 { - b, err = q.q.appendWith(fmter, b) - if err != nil { - return nil, err - } - } - - b = append(b, "UPDATE "...) - - b, err = q.q.appendFirstTableWithAlias(fmter, b) - if err != nil { - return nil, err - } - - b, err = q.mustAppendSet(fmter, b) - if err != nil { - return nil, err - } - - isSliceModelWithData := q.q.isSliceModelWithData() - if isSliceModelWithData || q.q.hasMultiTables() { - b = append(b, " FROM "...) - b, err = q.q.appendOtherTables(fmter, b) - if err != nil { - return nil, err - } - - if isSliceModelWithData { - b, err = q.appendSliceModelData(fmter, b) - if err != nil { - return nil, err - } - } - } - - b, err = q.mustAppendWhere(fmter, b, isSliceModelWithData) - if err != nil { - return nil, err - } - - if len(q.q.returning) > 0 { - b, err = q.q.appendReturning(fmter, b) - if err != nil { - return nil, err - } - } - - return b, q.q.stickyErr -} - -func (q *UpdateQuery) mustAppendWhere( - fmter QueryFormatter, b []byte, isSliceModelWithData bool, -) (_ []byte, err error) { - b = append(b, " WHERE "...) - - if !isSliceModelWithData { - return q.q.mustAppendWhere(fmter, b) - } - - if len(q.q.where) > 0 { - return q.q.appendWhere(fmter, b) - } - - table := q.q.tableModel.Table() - err = table.checkPKs() - if err != nil { - return nil, err - } - - b = appendWhereColumnAndColumn(b, table.Alias, table.PKs) - return b, nil -} - -func (q *UpdateQuery) mustAppendSet(fmter QueryFormatter, b []byte) (_ []byte, err error) { - if len(q.q.set) > 0 { - return q.q.appendSet(fmter, b) - } - - b = append(b, " SET "...) - - if m, ok := q.q.model.(*mapModel); ok { - return q.appendMapSet(b, m.m), nil - } - - if !q.q.hasTableModel() { - return nil, errModelNil - } - - value := q.q.tableModel.Value() - if value.Kind() == reflect.Struct { - b, err = q.appendSetStruct(fmter, b, value) - } else { - if value.Len() > 0 { - b, err = q.appendSetSlice(b) - } else { - err = fmt.Errorf("pg: can't bulk-update empty slice %s", value.Type()) - } - } - if err != nil { - return nil, err - } - - return b, nil -} - -func (q *UpdateQuery) appendMapSet(b []byte, m map[string]interface{}) []byte { - keys := make([]string, 0, len(m)) - - for k := range m { - keys = append(keys, k) - } - sort.Strings(keys) - - for i, k := range keys { - if i > 0 { - b = append(b, ", "...) - } - - b = types.AppendIdent(b, k, 1) - b = append(b, " = "...) - if q.placeholder { - b = append(b, '?') - } else { - b = types.Append(b, m[k], 1) - } - } - - return b -} - -func (q *UpdateQuery) appendSetStruct(fmter QueryFormatter, b []byte, strct reflect.Value) ([]byte, error) { - fields, err := q.q.getFields() - if err != nil { - return nil, err - } - - if len(fields) == 0 { - fields = q.q.tableModel.Table().DataFields - } - - pos := len(b) - for _, f := range fields { - if q.omitZero && f.NullZero() && f.HasZeroValue(strct) { - continue - } - - if len(b) != pos { - b = append(b, ", "...) - pos = len(b) - } - - b = append(b, f.Column...) - b = append(b, " = "...) - - if q.placeholder { - b = append(b, '?') - continue - } - - app, ok := q.q.modelValues[f.SQLName] - if ok { - b, err = app.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - } else { - b = f.AppendValue(b, strct, 1) - } - } - - for i, v := range q.q.extraValues { - if i > 0 || len(fields) > 0 { - b = append(b, ", "...) - } - - b = append(b, v.column...) - b = append(b, " = "...) - - b, err = v.value.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - } - - return b, nil -} - -func (q *UpdateQuery) appendSetSlice(b []byte) ([]byte, error) { - fields, err := q.q.getFields() - if err != nil { - return nil, err - } - - if len(fields) == 0 { - fields = q.q.tableModel.Table().DataFields - } - - var table *Table - if q.omitZero { - table = q.q.tableModel.Table() - } - - for i, f := range fields { - if i > 0 { - b = append(b, ", "...) - } - - b = append(b, f.Column...) - b = append(b, " = "...) - if q.omitZero && table != nil { - b = append(b, "COALESCE("...) - } - b = append(b, "_data."...) - b = append(b, f.Column...) - if q.omitZero && table != nil { - b = append(b, ", "...) - if table.Alias != table.SQLName { - b = append(b, table.Alias...) - b = append(b, '.') - } - b = append(b, f.Column...) - b = append(b, ")"...) - } - } - - return b, nil -} - -func (q *UpdateQuery) appendSliceModelData(fmter QueryFormatter, b []byte) ([]byte, error) { - columns, err := q.q.getDataFields() - if err != nil { - return nil, err - } - - if len(columns) > 0 { - columns = append(columns, q.q.tableModel.Table().PKs...) - } else { - columns = q.q.tableModel.Table().Fields - } - - return q.appendSliceValues(fmter, b, columns, q.q.tableModel.Value()) -} - -func (q *UpdateQuery) appendSliceValues( - fmter QueryFormatter, b []byte, fields []*Field, slice reflect.Value, -) (_ []byte, err error) { - b = append(b, "(VALUES ("...) - - if q.placeholder { - b, err = q.appendValues(fmter, b, fields, reflect.Value{}) - if err != nil { - return nil, err - } - } else { - sliceLen := slice.Len() - for i := 0; i < sliceLen; i++ { - if i > 0 { - b = append(b, "), ("...) - } - b, err = q.appendValues(fmter, b, fields, slice.Index(i)) - if err != nil { - return nil, err - } - } - } - - b = append(b, ")) AS _data("...) - b = appendColumns(b, "", fields) - b = append(b, ")"...) - - return b, nil -} - -func (q *UpdateQuery) appendValues( - fmter QueryFormatter, b []byte, fields []*Field, strct reflect.Value, -) (_ []byte, err error) { - for i, f := range fields { - if i > 0 { - b = append(b, ", "...) - } - - app, ok := q.q.modelValues[f.SQLName] - if ok { - b, err = app.AppendQuery(fmter, b) - if err != nil { - return nil, err - } - continue - } - - if q.placeholder { - b = append(b, '?') - } else { - b = f.AppendValue(b, indirect(strct), 1) - } - - b = append(b, "::"...) - b = append(b, f.SQLType...) - } - return b, nil -} - -func appendWhereColumnAndColumn(b []byte, alias types.Safe, fields []*Field) []byte { - for i, f := range fields { - if i > 0 { - b = append(b, " AND "...) - } - b = append(b, alias...) - b = append(b, '.') - b = append(b, f.Column...) - b = append(b, " = _data."...) - b = append(b, f.Column...) - } - return b -} diff --git a/vendor/github.com/go-pg/pg/v10/orm/util.go b/vendor/github.com/go-pg/pg/v10/orm/util.go deleted file mode 100644 index b7963ba0b..000000000 --- a/vendor/github.com/go-pg/pg/v10/orm/util.go +++ /dev/null @@ -1,151 +0,0 @@ -package orm - -import ( - "reflect" - - "github.com/go-pg/pg/v10/types" -) - -func indirect(v reflect.Value) reflect.Value { - switch v.Kind() { - case reflect.Interface: - return indirect(v.Elem()) - case reflect.Ptr: - return v.Elem() - default: - return v - } -} - -func indirectType(t reflect.Type) reflect.Type { - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - return t -} - -func sliceElemType(v reflect.Value) reflect.Type { - elemType := v.Type().Elem() - if elemType.Kind() == reflect.Interface && v.Len() > 0 { - return indirect(v.Index(0).Elem()).Type() - } - return indirectType(elemType) -} - -func typeByIndex(t reflect.Type, index []int) reflect.Type { - for _, x := range index { - switch t.Kind() { - case reflect.Ptr: - t = t.Elem() - case reflect.Slice: - t = indirectType(t.Elem()) - } - t = t.Field(x).Type - } - return indirectType(t) -} - -func fieldByIndex(v reflect.Value, index []int) (_ reflect.Value, ok bool) { - if len(index) == 1 { - return v.Field(index[0]), true - } - - for i, idx := range index { - if i > 0 { - if v.Kind() == reflect.Ptr { - if v.IsNil() { - return v, false - } - v = v.Elem() - } - } - v = v.Field(idx) - } - return v, true -} - -func fieldByIndexAlloc(v reflect.Value, index []int) reflect.Value { - if len(index) == 1 { - return v.Field(index[0]) - } - - for i, idx := range index { - if i > 0 { - v = indirectNil(v) - } - v = v.Field(idx) - } - return v -} - -func indirectNil(v reflect.Value) reflect.Value { - if v.Kind() == reflect.Ptr { - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) - } - v = v.Elem() - } - return v -} - -func walk(v reflect.Value, index []int, fn func(reflect.Value)) { - v = reflect.Indirect(v) - switch v.Kind() { - case reflect.Slice: - sliceLen := v.Len() - for i := 0; i < sliceLen; i++ { - visitField(v.Index(i), index, fn) - } - default: - visitField(v, index, fn) - } -} - -func visitField(v reflect.Value, index []int, fn func(reflect.Value)) { - v = reflect.Indirect(v) - if len(index) > 0 { - v = v.Field(index[0]) - if v.Kind() == reflect.Ptr && v.IsNil() { - return - } - walk(v, index[1:], fn) - } else { - fn(v) - } -} - -func dstValues(model TableModel, fields []*Field) map[string][]reflect.Value { - fieldIndex := model.Relation().Field.Index - m := make(map[string][]reflect.Value) - var id []byte - walk(model.Root(), model.ParentIndex(), func(v reflect.Value) { - id = modelID(id[:0], v, fields) - m[string(id)] = append(m[string(id)], v.FieldByIndex(fieldIndex)) - }) - return m -} - -func modelID(b []byte, v reflect.Value, fields []*Field) []byte { - for i, f := range fields { - if i > 0 { - b = append(b, ',') - } - b = f.AppendValue(b, v, 0) - } - return b -} - -func appendColumns(b []byte, table types.Safe, fields []*Field) []byte { - for i, f := range fields { - if i > 0 { - b = append(b, ", "...) - } - - if len(table) > 0 { - b = append(b, table...) - b = append(b, '.') - } - b = append(b, f.Column...) - } - return b -} diff --git a/vendor/github.com/go-pg/pg/v10/pg.go b/vendor/github.com/go-pg/pg/v10/pg.go deleted file mode 100644 index 923ef6bef..000000000 --- a/vendor/github.com/go-pg/pg/v10/pg.go +++ /dev/null @@ -1,274 +0,0 @@ -package pg - -import ( - "context" - "io" - "strconv" - - "github.com/go-pg/pg/v10/internal" - "github.com/go-pg/pg/v10/orm" - "github.com/go-pg/pg/v10/types" -) - -// Discard is used with Query and QueryOne to discard rows. -var Discard orm.Discard - -// NullTime is a time.Time wrapper that marshals zero time as JSON null and -// PostgreSQL NULL. -type NullTime = types.NullTime - -// Scan returns ColumnScanner that copies the columns in the -// row into the values. -func Scan(values ...interface{}) orm.ColumnScanner { - return orm.Scan(values...) -} - -// Safe represents a safe SQL query. -type Safe = types.Safe - -// Ident represents a SQL identifier, e.g. table or column name. -type Ident = types.Ident - -// SafeQuery replaces any placeholders found in the query. -func SafeQuery(query string, params ...interface{}) *orm.SafeQueryAppender { - return orm.SafeQuery(query, params...) -} - -// In accepts a slice and returns a wrapper that can be used with PostgreSQL -// IN operator: -// -// Where("id IN (?)", pg.In([]int{1, 2, 3, 4})) -// -// produces -// -// WHERE id IN (1, 2, 3, 4) -func In(slice interface{}) types.ValueAppender { - return types.In(slice) -} - -// InMulti accepts multiple values and returns a wrapper that can be used -// with PostgreSQL IN operator: -// -// Where("(id1, id2) IN (?)", pg.InMulti([]int{1, 2}, []int{3, 4})) -// -// produces -// -// WHERE (id1, id2) IN ((1, 2), (3, 4)) -func InMulti(values ...interface{}) types.ValueAppender { - return types.InMulti(values...) -} - -// Array accepts a slice and returns a wrapper for working with PostgreSQL -// array data type. -// -// For struct fields you can use array tag: -// -// Emails []string `pg:",array"` -func Array(v interface{}) *types.Array { - return types.NewArray(v) -} - -// Hstore accepts a map and returns a wrapper for working with hstore data type. -// Supported map types are: -// - map[string]string -// -// For struct fields you can use hstore tag: -// -// Attrs map[string]string `pg:",hstore"` -func Hstore(v interface{}) *types.Hstore { - return types.NewHstore(v) -} - -// SetLogger sets the logger to the given one. -func SetLogger(logger internal.Logging) { - internal.Logger = logger -} - -//------------------------------------------------------------------------------ - -type Query = orm.Query - -// Model returns a new query for the optional model. -func Model(model ...interface{}) *Query { - return orm.NewQuery(nil, model...) -} - -// ModelContext returns a new query for the optional model with a context. -func ModelContext(c context.Context, model ...interface{}) *Query { - return orm.NewQueryContext(c, nil, model...) -} - -// DBI is a DB interface implemented by *DB and *Tx. -type DBI interface { - Model(model ...interface{}) *Query - ModelContext(c context.Context, model ...interface{}) *Query - - Exec(query interface{}, params ...interface{}) (Result, error) - ExecContext(c context.Context, query interface{}, params ...interface{}) (Result, error) - ExecOne(query interface{}, params ...interface{}) (Result, error) - ExecOneContext(c context.Context, query interface{}, params ...interface{}) (Result, error) - Query(model, query interface{}, params ...interface{}) (Result, error) - QueryContext(c context.Context, model, query interface{}, params ...interface{}) (Result, error) - QueryOne(model, query interface{}, params ...interface{}) (Result, error) - QueryOneContext(c context.Context, model, query interface{}, params ...interface{}) (Result, error) - - Begin() (*Tx, error) - RunInTransaction(ctx context.Context, fn func(*Tx) error) error - - CopyFrom(r io.Reader, query interface{}, params ...interface{}) (Result, error) - CopyTo(w io.Writer, query interface{}, params ...interface{}) (Result, error) -} - -var ( - _ DBI = (*DB)(nil) - _ DBI = (*Tx)(nil) -) - -//------------------------------------------------------------------------------ - -// Strings is a type alias for a slice of strings. -type Strings []string - -var ( - _ orm.HooklessModel = (*Strings)(nil) - _ types.ValueAppender = (*Strings)(nil) -) - -// Init initializes the Strings slice. -func (strings *Strings) Init() error { - if s := *strings; len(s) > 0 { - *strings = s[:0] - } - return nil -} - -// NextColumnScanner ... -func (strings *Strings) NextColumnScanner() orm.ColumnScanner { - return strings -} - -// AddColumnScanner ... -func (Strings) AddColumnScanner(_ orm.ColumnScanner) error { - return nil -} - -// ScanColumn scans the columns and appends them to `strings`. -func (strings *Strings) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { - b := make([]byte, n) - _, err := io.ReadFull(rd, b) - if err != nil { - return err - } - - *strings = append(*strings, internal.BytesToString(b)) - return nil -} - -// AppendValue appends the values from `strings` to the given byte slice. -func (strings Strings) AppendValue(dst []byte, quote int) ([]byte, error) { - if len(strings) == 0 { - return dst, nil - } - - for _, s := range strings { - dst = types.AppendString(dst, s, 1) - dst = append(dst, ',') - } - dst = dst[:len(dst)-1] - return dst, nil -} - -//------------------------------------------------------------------------------ - -// Ints is a type alias for a slice of int64 values. -type Ints []int64 - -var ( - _ orm.HooklessModel = (*Ints)(nil) - _ types.ValueAppender = (*Ints)(nil) -) - -// Init initializes the Int slice. -func (ints *Ints) Init() error { - if s := *ints; len(s) > 0 { - *ints = s[:0] - } - return nil -} - -// NewColumnScanner ... -func (ints *Ints) NextColumnScanner() orm.ColumnScanner { - return ints -} - -// AddColumnScanner ... -func (Ints) AddColumnScanner(_ orm.ColumnScanner) error { - return nil -} - -// ScanColumn scans the columns and appends them to `ints`. -func (ints *Ints) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { - num, err := types.ScanInt64(rd, n) - if err != nil { - return err - } - - *ints = append(*ints, num) - return nil -} - -// AppendValue appends the values from `ints` to the given byte slice. -func (ints Ints) AppendValue(dst []byte, quote int) ([]byte, error) { - if len(ints) == 0 { - return dst, nil - } - - for _, v := range ints { - dst = strconv.AppendInt(dst, v, 10) - dst = append(dst, ',') - } - dst = dst[:len(dst)-1] - return dst, nil -} - -//------------------------------------------------------------------------------ - -// IntSet is a set of int64 values. -type IntSet map[int64]struct{} - -var _ orm.HooklessModel = (*IntSet)(nil) - -// Init initializes the IntSet. -func (set *IntSet) Init() error { - if len(*set) > 0 { - *set = make(map[int64]struct{}) - } - return nil -} - -// NextColumnScanner ... -func (set *IntSet) NextColumnScanner() orm.ColumnScanner { - return set -} - -// AddColumnScanner ... -func (IntSet) AddColumnScanner(_ orm.ColumnScanner) error { - return nil -} - -// ScanColumn scans the columns and appends them to `IntSet`. -func (set *IntSet) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { - num, err := types.ScanInt64(rd, n) - if err != nil { - return err - } - - setVal := *set - if setVal == nil { - *set = make(IntSet) - setVal = *set - } - - setVal[num] = struct{}{} - return nil -} diff --git a/vendor/github.com/go-pg/pg/v10/result.go b/vendor/github.com/go-pg/pg/v10/result.go deleted file mode 100644 index b8d8d9e45..000000000 --- a/vendor/github.com/go-pg/pg/v10/result.go +++ /dev/null @@ -1,53 +0,0 @@ -package pg - -import ( - "bytes" - "strconv" - - "github.com/go-pg/pg/v10/internal" - "github.com/go-pg/pg/v10/orm" -) - -// Result summarizes an executed SQL command. -type Result = orm.Result - -// A result summarizes an executed SQL command. -type result struct { - model orm.Model - - affected int - returned int -} - -var _ Result = (*result)(nil) - -//nolint -func (res *result) parse(b []byte) error { - res.affected = -1 - - ind := bytes.LastIndexByte(b, ' ') - if ind == -1 { - return nil - } - - s := internal.BytesToString(b[ind+1 : len(b)-1]) - - affected, err := strconv.Atoi(s) - if err == nil { - res.affected = affected - } - - return nil -} - -func (res *result) Model() orm.Model { - return res.model -} - -func (res *result) RowsAffected() int { - return res.affected -} - -func (res *result) RowsReturned() int { - return res.returned -} diff --git a/vendor/github.com/go-pg/pg/v10/stmt.go b/vendor/github.com/go-pg/pg/v10/stmt.go deleted file mode 100644 index 528788379..000000000 --- a/vendor/github.com/go-pg/pg/v10/stmt.go +++ /dev/null @@ -1,282 +0,0 @@ -package pg - -import ( - "context" - "errors" - - "github.com/go-pg/pg/v10/internal" - "github.com/go-pg/pg/v10/internal/pool" - "github.com/go-pg/pg/v10/orm" - "github.com/go-pg/pg/v10/types" -) - -var errStmtClosed = errors.New("pg: statement is closed") - -// Stmt is a prepared statement. Stmt is safe for concurrent use by -// multiple goroutines. -type Stmt struct { - db *baseDB - stickyErr error - - q string - name string - columns []types.ColumnInfo -} - -func prepareStmt(db *baseDB, q string) (*Stmt, error) { - stmt := &Stmt{ - db: db, - - q: q, - } - - err := stmt.prepare(context.TODO(), q) - if err != nil { - _ = stmt.Close() - return nil, err - } - return stmt, nil -} - -func (stmt *Stmt) prepare(ctx context.Context, q string) error { - var lastErr error - for attempt := 0; attempt <= stmt.db.opt.MaxRetries; attempt++ { - if attempt > 0 { - if err := internal.Sleep(ctx, stmt.db.retryBackoff(attempt-1)); err != nil { - return err - } - - err := stmt.db.pool.(*pool.StickyConnPool).Reset(ctx) - if err != nil { - return err - } - } - - lastErr = stmt.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { - var err error - stmt.name, stmt.columns, err = stmt.db.prepare(ctx, cn, q) - return err - }) - if !stmt.db.shouldRetry(lastErr) { - break - } - } - return lastErr -} - -func (stmt *Stmt) withConn(c context.Context, fn func(context.Context, *pool.Conn) error) error { - if stmt.stickyErr != nil { - return stmt.stickyErr - } - err := stmt.db.withConn(c, fn) - if err == pool.ErrClosed { - return errStmtClosed - } - return err -} - -// Exec executes a prepared statement with the given parameters. -func (stmt *Stmt) Exec(params ...interface{}) (Result, error) { - return stmt.exec(context.TODO(), params...) -} - -// ExecContext executes a prepared statement with the given parameters. -func (stmt *Stmt) ExecContext(c context.Context, params ...interface{}) (Result, error) { - return stmt.exec(c, params...) -} - -func (stmt *Stmt) exec(ctx context.Context, params ...interface{}) (Result, error) { - ctx, evt, err := stmt.db.beforeQuery(ctx, stmt.db.db, nil, stmt.q, params, nil) - if err != nil { - return nil, err - } - - var res Result - var lastErr error - for attempt := 0; attempt <= stmt.db.opt.MaxRetries; attempt++ { - if attempt > 0 { - lastErr = internal.Sleep(ctx, stmt.db.retryBackoff(attempt-1)) - if lastErr != nil { - break - } - } - - lastErr = stmt.withConn(ctx, func(c context.Context, cn *pool.Conn) error { - res, err = stmt.extQuery(ctx, cn, stmt.name, params...) - return err - }) - if !stmt.db.shouldRetry(lastErr) { - break - } - } - - if err := stmt.db.afterQuery(ctx, evt, res, lastErr); err != nil { - return nil, err - } - return res, lastErr -} - -// ExecOne acts like Exec, but query must affect only one row. It -// returns ErrNoRows error when query returns zero rows or -// ErrMultiRows when query returns multiple rows. -func (stmt *Stmt) ExecOne(params ...interface{}) (Result, error) { - return stmt.execOne(context.Background(), params...) -} - -// ExecOneContext acts like ExecOne but additionally receives a context. -func (stmt *Stmt) ExecOneContext(c context.Context, params ...interface{}) (Result, error) { - return stmt.execOne(c, params...) -} - -func (stmt *Stmt) execOne(c context.Context, params ...interface{}) (Result, error) { - res, err := stmt.ExecContext(c, params...) - if err != nil { - return nil, err - } - - if err := internal.AssertOneRow(res.RowsAffected()); err != nil { - return nil, err - } - return res, nil -} - -// Query executes a prepared query statement with the given parameters. -func (stmt *Stmt) Query(model interface{}, params ...interface{}) (Result, error) { - return stmt.query(context.Background(), model, params...) -} - -// QueryContext acts like Query but additionally receives a context. -func (stmt *Stmt) QueryContext(c context.Context, model interface{}, params ...interface{}) (Result, error) { - return stmt.query(c, model, params...) -} - -func (stmt *Stmt) query(ctx context.Context, model interface{}, params ...interface{}) (Result, error) { - ctx, evt, err := stmt.db.beforeQuery(ctx, stmt.db.db, model, stmt.q, params, nil) - if err != nil { - return nil, err - } - - var res Result - var lastErr error - for attempt := 0; attempt <= stmt.db.opt.MaxRetries; attempt++ { - if attempt > 0 { - lastErr = internal.Sleep(ctx, stmt.db.retryBackoff(attempt-1)) - if lastErr != nil { - break - } - } - - lastErr = stmt.withConn(ctx, func(c context.Context, cn *pool.Conn) error { - res, err = stmt.extQueryData(ctx, cn, stmt.name, model, stmt.columns, params...) - return err - }) - if !stmt.db.shouldRetry(lastErr) { - break - } - } - - if err := stmt.db.afterQuery(ctx, evt, res, lastErr); err != nil { - return nil, err - } - return res, lastErr -} - -// QueryOne acts like Query, but query must return only one row. It -// returns ErrNoRows error when query returns zero rows or -// ErrMultiRows when query returns multiple rows. -func (stmt *Stmt) QueryOne(model interface{}, params ...interface{}) (Result, error) { - return stmt.queryOne(context.Background(), model, params...) -} - -// QueryOneContext acts like QueryOne but additionally receives a context. -func (stmt *Stmt) QueryOneContext(c context.Context, model interface{}, params ...interface{}) (Result, error) { - return stmt.queryOne(c, model, params...) -} - -func (stmt *Stmt) queryOne(c context.Context, model interface{}, params ...interface{}) (Result, error) { - mod, err := orm.NewModel(model) - if err != nil { - return nil, err - } - - res, err := stmt.QueryContext(c, mod, params...) - if err != nil { - return nil, err - } - - if err := internal.AssertOneRow(res.RowsAffected()); err != nil { - return nil, err - } - return res, nil -} - -// Close closes the statement. -func (stmt *Stmt) Close() error { - var firstErr error - - if stmt.name != "" { - firstErr = stmt.closeStmt() - } - - err := stmt.db.Close() - if err != nil && firstErr == nil { - firstErr = err - } - - return firstErr -} - -func (stmt *Stmt) extQuery( - c context.Context, cn *pool.Conn, name string, params ...interface{}, -) (Result, error) { - err := cn.WithWriter(c, stmt.db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { - return writeBindExecuteMsg(wb, name, params...) - }) - if err != nil { - return nil, err - } - - var res Result - err = cn.WithReader(c, stmt.db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { - res, err = readExtQuery(rd) - return err - }) - if err != nil { - return nil, err - } - - return res, nil -} - -func (stmt *Stmt) extQueryData( - c context.Context, - cn *pool.Conn, - name string, - model interface{}, - columns []types.ColumnInfo, - params ...interface{}, -) (Result, error) { - err := cn.WithWriter(c, stmt.db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { - return writeBindExecuteMsg(wb, name, params...) - }) - if err != nil { - return nil, err - } - - var res *result - err = cn.WithReader(c, stmt.db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { - res, err = readExtQueryData(c, rd, model, columns) - return err - }) - if err != nil { - return nil, err - } - - return res, nil -} - -func (stmt *Stmt) closeStmt() error { - return stmt.withConn(context.TODO(), func(c context.Context, cn *pool.Conn) error { - return stmt.db.closeStmt(c, cn, stmt.name) - }) -} diff --git a/vendor/github.com/go-pg/pg/v10/tx.go b/vendor/github.com/go-pg/pg/v10/tx.go deleted file mode 100644 index db444ff65..000000000 --- a/vendor/github.com/go-pg/pg/v10/tx.go +++ /dev/null @@ -1,388 +0,0 @@ -package pg - -import ( - "context" - "errors" - "io" - "sync" - "sync/atomic" - - "github.com/go-pg/pg/v10/internal" - "github.com/go-pg/pg/v10/internal/pool" - "github.com/go-pg/pg/v10/orm" -) - -// ErrTxDone is returned by any operation that is performed on a transaction -// that has already been committed or rolled back. -var ErrTxDone = errors.New("pg: transaction has already been committed or rolled back") - -// Tx is an in-progress database transaction. It is safe for concurrent use -// by multiple goroutines. -// -// A transaction must end with a call to Commit or Rollback. -// -// After a call to Commit or Rollback, all operations on the transaction fail -// with ErrTxDone. -// -// The statements prepared for a transaction by calling the transaction's -// Prepare or Stmt methods are closed by the call to Commit or Rollback. -type Tx struct { - db *baseDB - ctx context.Context - - stmtsMu sync.Mutex - stmts []*Stmt - - _closed int32 -} - -var _ orm.DB = (*Tx)(nil) - -// Context returns the context.Context of the transaction. -func (tx *Tx) Context() context.Context { - return tx.ctx -} - -// Begin starts a transaction. Most callers should use RunInTransaction instead. -func (db *baseDB) Begin() (*Tx, error) { - return db.BeginContext(db.db.Context()) -} - -func (db *baseDB) BeginContext(ctx context.Context) (*Tx, error) { - tx := &Tx{ - db: db.withPool(pool.NewStickyConnPool(db.pool)), - ctx: ctx, - } - - err := tx.begin(ctx) - if err != nil { - tx.close() - return nil, err - } - - return tx, nil -} - -// RunInTransaction runs a function in a transaction. If function -// returns an error transaction is rolled back, otherwise transaction -// is committed. -func (db *baseDB) RunInTransaction(ctx context.Context, fn func(*Tx) error) error { - tx, err := db.BeginContext(ctx) - if err != nil { - return err - } - return tx.RunInTransaction(ctx, fn) -} - -// Begin returns current transaction. It does not start new transaction. -func (tx *Tx) Begin() (*Tx, error) { - return tx, nil -} - -// RunInTransaction runs a function in the transaction. If function -// returns an error transaction is rolled back, otherwise transaction -// is committed. -func (tx *Tx) RunInTransaction(ctx context.Context, fn func(*Tx) error) error { - defer func() { - if err := recover(); err != nil { - if err := tx.RollbackContext(ctx); err != nil { - internal.Logger.Printf(ctx, "tx.Rollback panicked: %s", err) - } - panic(err) - } - }() - - if err := fn(tx); err != nil { - if err := tx.RollbackContext(ctx); err != nil { - internal.Logger.Printf(ctx, "tx.Rollback failed: %s", err) - } - return err - } - return tx.CommitContext(ctx) -} - -func (tx *Tx) withConn(c context.Context, fn func(context.Context, *pool.Conn) error) error { - err := tx.db.withConn(c, fn) - if tx.closed() && err == pool.ErrClosed { - return ErrTxDone - } - return err -} - -// Stmt returns a transaction-specific prepared statement -// from an existing statement. -func (tx *Tx) Stmt(stmt *Stmt) *Stmt { - stmt, err := tx.Prepare(stmt.q) - if err != nil { - return &Stmt{stickyErr: err} - } - return stmt -} - -// Prepare creates a prepared statement for use within a transaction. -// -// The returned statement operates within the transaction and can no longer -// be used once the transaction has been committed or rolled back. -// -// To use an existing prepared statement on this transaction, see Tx.Stmt. -func (tx *Tx) Prepare(q string) (*Stmt, error) { - tx.stmtsMu.Lock() - defer tx.stmtsMu.Unlock() - - db := tx.db.withPool(pool.NewStickyConnPool(tx.db.pool)) - stmt, err := prepareStmt(db, q) - if err != nil { - return nil, err - } - tx.stmts = append(tx.stmts, stmt) - - return stmt, nil -} - -// Exec is an alias for DB.Exec. -func (tx *Tx) Exec(query interface{}, params ...interface{}) (Result, error) { - return tx.exec(tx.ctx, query, params...) -} - -// ExecContext acts like Exec but additionally receives a context. -func (tx *Tx) ExecContext(c context.Context, query interface{}, params ...interface{}) (Result, error) { - return tx.exec(c, query, params...) -} - -func (tx *Tx) exec(ctx context.Context, query interface{}, params ...interface{}) (Result, error) { - wb := pool.GetWriteBuffer() - defer pool.PutWriteBuffer(wb) - - if err := writeQueryMsg(wb, tx.db.fmter, query, params...); err != nil { - return nil, err - } - - ctx, evt, err := tx.db.beforeQuery(ctx, tx, nil, query, params, wb.Query()) - if err != nil { - return nil, err - } - - var res Result - lastErr := tx.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { - res, err = tx.db.simpleQuery(ctx, cn, wb) - return err - }) - - if err := tx.db.afterQuery(ctx, evt, res, lastErr); err != nil { - return nil, err - } - return res, lastErr -} - -// ExecOne is an alias for DB.ExecOne. -func (tx *Tx) ExecOne(query interface{}, params ...interface{}) (Result, error) { - return tx.execOne(tx.ctx, query, params...) -} - -// ExecOneContext acts like ExecOne but additionally receives a context. -func (tx *Tx) ExecOneContext(c context.Context, query interface{}, params ...interface{}) (Result, error) { - return tx.execOne(c, query, params...) -} - -func (tx *Tx) execOne(c context.Context, query interface{}, params ...interface{}) (Result, error) { - res, err := tx.ExecContext(c, query, params...) - if err != nil { - return nil, err - } - - if err := internal.AssertOneRow(res.RowsAffected()); err != nil { - return nil, err - } - return res, nil -} - -// Query is an alias for DB.Query. -func (tx *Tx) Query(model interface{}, query interface{}, params ...interface{}) (Result, error) { - return tx.query(tx.ctx, model, query, params...) -} - -// QueryContext acts like Query but additionally receives a context. -func (tx *Tx) QueryContext( - c context.Context, - model interface{}, - query interface{}, - params ...interface{}, -) (Result, error) { - return tx.query(c, model, query, params...) -} - -func (tx *Tx) query( - ctx context.Context, - model interface{}, - query interface{}, - params ...interface{}, -) (Result, error) { - wb := pool.GetWriteBuffer() - defer pool.PutWriteBuffer(wb) - - if err := writeQueryMsg(wb, tx.db.fmter, query, params...); err != nil { - return nil, err - } - - ctx, evt, err := tx.db.beforeQuery(ctx, tx, model, query, params, wb.Query()) - if err != nil { - return nil, err - } - - var res *result - lastErr := tx.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { - res, err = tx.db.simpleQueryData(ctx, cn, model, wb) - return err - }) - - if err := tx.db.afterQuery(ctx, evt, res, err); err != nil { - return nil, err - } - return res, lastErr -} - -// QueryOne is an alias for DB.QueryOne. -func (tx *Tx) QueryOne(model interface{}, query interface{}, params ...interface{}) (Result, error) { - return tx.queryOne(tx.ctx, model, query, params...) -} - -// QueryOneContext acts like QueryOne but additionally receives a context. -func (tx *Tx) QueryOneContext( - c context.Context, - model interface{}, - query interface{}, - params ...interface{}, -) (Result, error) { - return tx.queryOne(c, model, query, params...) -} - -func (tx *Tx) queryOne( - c context.Context, - model interface{}, - query interface{}, - params ...interface{}, -) (Result, error) { - mod, err := orm.NewModel(model) - if err != nil { - return nil, err - } - - res, err := tx.QueryContext(c, mod, query, params...) - if err != nil { - return nil, err - } - - if err := internal.AssertOneRow(res.RowsAffected()); err != nil { - return nil, err - } - return res, nil -} - -// Model is an alias for DB.Model. -func (tx *Tx) Model(model ...interface{}) *Query { - return orm.NewQuery(tx, model...) -} - -// ModelContext acts like Model but additionally receives a context. -func (tx *Tx) ModelContext(c context.Context, model ...interface{}) *Query { - return orm.NewQueryContext(c, tx, model...) -} - -// CopyFrom is an alias for DB.CopyFrom. -func (tx *Tx) CopyFrom(r io.Reader, query interface{}, params ...interface{}) (res Result, err error) { - err = tx.withConn(tx.ctx, func(c context.Context, cn *pool.Conn) error { - res, err = tx.db.copyFrom(c, cn, r, query, params...) - return err - }) - return res, err -} - -// CopyTo is an alias for DB.CopyTo. -func (tx *Tx) CopyTo(w io.Writer, query interface{}, params ...interface{}) (res Result, err error) { - err = tx.withConn(tx.ctx, func(c context.Context, cn *pool.Conn) error { - res, err = tx.db.copyTo(c, cn, w, query, params...) - return err - }) - return res, err -} - -// Formatter is an alias for DB.Formatter. -func (tx *Tx) Formatter() orm.QueryFormatter { - return tx.db.Formatter() -} - -func (tx *Tx) begin(ctx context.Context) error { - var lastErr error - for attempt := 0; attempt <= tx.db.opt.MaxRetries; attempt++ { - if attempt > 0 { - if err := internal.Sleep(ctx, tx.db.retryBackoff(attempt-1)); err != nil { - return err - } - - err := tx.db.pool.(*pool.StickyConnPool).Reset(ctx) - if err != nil { - return err - } - } - - _, lastErr = tx.ExecContext(ctx, "BEGIN") - if !tx.db.shouldRetry(lastErr) { - break - } - } - return lastErr -} - -func (tx *Tx) Commit() error { - return tx.CommitContext(tx.ctx) -} - -// Commit commits the transaction. -func (tx *Tx) CommitContext(ctx context.Context) error { - _, err := tx.ExecContext(internal.UndoContext(ctx), "COMMIT") - tx.close() - return err -} - -func (tx *Tx) Rollback() error { - return tx.RollbackContext(tx.ctx) -} - -// Rollback aborts the transaction. -func (tx *Tx) RollbackContext(ctx context.Context) error { - _, err := tx.ExecContext(internal.UndoContext(ctx), "ROLLBACK") - tx.close() - return err -} - -func (tx *Tx) Close() error { - return tx.CloseContext(tx.ctx) -} - -// Close calls Rollback if the tx has not already been committed or rolled back. -func (tx *Tx) CloseContext(ctx context.Context) error { - if tx.closed() { - return nil - } - return tx.RollbackContext(ctx) -} - -func (tx *Tx) close() { - if !atomic.CompareAndSwapInt32(&tx._closed, 0, 1) { - return - } - - tx.stmtsMu.Lock() - defer tx.stmtsMu.Unlock() - - for _, stmt := range tx.stmts { - _ = stmt.Close() - } - tx.stmts = nil - - _ = tx.db.Close() -} - -func (tx *Tx) closed() bool { - return atomic.LoadInt32(&tx._closed) == 1 -} diff --git a/vendor/github.com/go-pg/pg/v10/types/append.go b/vendor/github.com/go-pg/pg/v10/types/append.go deleted file mode 100644 index 05be2a0fa..000000000 --- a/vendor/github.com/go-pg/pg/v10/types/append.go +++ /dev/null @@ -1,201 +0,0 @@ -package types - -import ( - "math" - "reflect" - "strconv" - "time" - "unicode/utf8" - - "github.com/tmthrgd/go-hex" -) - -func Append(b []byte, v interface{}, flags int) []byte { - switch v := v.(type) { - case nil: - return AppendNull(b, flags) - case bool: - return appendBool(b, v) - case int32: - return strconv.AppendInt(b, int64(v), 10) - case int64: - return strconv.AppendInt(b, v, 10) - case int: - return strconv.AppendInt(b, int64(v), 10) - case float32: - return appendFloat(b, float64(v), flags, 32) - case float64: - return appendFloat(b, v, flags, 64) - case string: - return AppendString(b, v, flags) - case time.Time: - return AppendTime(b, v, flags) - case []byte: - return AppendBytes(b, v, flags) - case ValueAppender: - return appendAppender(b, v, flags) - default: - return appendValue(b, reflect.ValueOf(v), flags) - } -} - -func AppendError(b []byte, err error) []byte { - b = append(b, "?!("...) - b = append(b, err.Error()...) - b = append(b, ')') - return b -} - -func AppendNull(b []byte, flags int) []byte { - if hasFlag(flags, quoteFlag) { - return append(b, "NULL"...) - } - return nil -} - -func appendBool(dst []byte, v bool) []byte { - if v { - return append(dst, "TRUE"...) - } - return append(dst, "FALSE"...) -} - -func appendFloat(dst []byte, v float64, flags int, bitSize int) []byte { - if hasFlag(flags, arrayFlag) { - return appendFloat2(dst, v, flags) - } - - switch { - case math.IsNaN(v): - if hasFlag(flags, quoteFlag) { - return append(dst, "'NaN'"...) - } - return append(dst, "NaN"...) - case math.IsInf(v, 1): - if hasFlag(flags, quoteFlag) { - return append(dst, "'Infinity'"...) - } - return append(dst, "Infinity"...) - case math.IsInf(v, -1): - if hasFlag(flags, quoteFlag) { - return append(dst, "'-Infinity'"...) - } - return append(dst, "-Infinity"...) - default: - return strconv.AppendFloat(dst, v, 'f', -1, bitSize) - } -} - -func appendFloat2(dst []byte, v float64, _ int) []byte { - switch { - case math.IsNaN(v): - return append(dst, "NaN"...) - case math.IsInf(v, 1): - return append(dst, "Infinity"...) - case math.IsInf(v, -1): - return append(dst, "-Infinity"...) - default: - return strconv.AppendFloat(dst, v, 'f', -1, 64) - } -} - -func AppendString(b []byte, s string, flags int) []byte { - if hasFlag(flags, arrayFlag) { - return appendString2(b, s, flags) - } - - if hasFlag(flags, quoteFlag) { - b = append(b, '\'') - for _, c := range s { - if c == '\000' { - continue - } - - if c == '\'' { - b = append(b, '\'', '\'') - } else { - b = appendRune(b, c) - } - } - b = append(b, '\'') - return b - } - - for _, c := range s { - if c != '\000' { - b = appendRune(b, c) - } - } - return b -} - -func appendString2(b []byte, s string, flags int) []byte { - b = append(b, '"') - for _, c := range s { - if c == '\000' { - continue - } - - switch c { - case '\'': - if hasFlag(flags, quoteFlag) { - b = append(b, '\'') - } - b = append(b, '\'') - case '"': - b = append(b, '\\', '"') - case '\\': - b = append(b, '\\', '\\') - default: - b = appendRune(b, c) - } - } - b = append(b, '"') - return b -} - -func appendRune(b []byte, r rune) []byte { - if r < utf8.RuneSelf { - return append(b, byte(r)) - } - l := len(b) - if cap(b)-l < utf8.UTFMax { - b = append(b, make([]byte, utf8.UTFMax)...) - } - n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r) - return b[:l+n] -} - -func AppendBytes(b []byte, bytes []byte, flags int) []byte { - if bytes == nil { - return AppendNull(b, flags) - } - - if hasFlag(flags, arrayFlag) { - b = append(b, `"\`...) - } else if hasFlag(flags, quoteFlag) { - b = append(b, '\'') - } - - b = append(b, `\x`...) - - s := len(b) - b = append(b, make([]byte, hex.EncodedLen(len(bytes)))...) - hex.Encode(b[s:], bytes) - - if hasFlag(flags, arrayFlag) { - b = append(b, '"') - } else if hasFlag(flags, quoteFlag) { - b = append(b, '\'') - } - - return b -} - -func appendAppender(b []byte, v ValueAppender, flags int) []byte { - bb, err := v.AppendValue(b, flags) - if err != nil { - return AppendError(b, err) - } - return bb -} diff --git a/vendor/github.com/go-pg/pg/v10/types/append_ident.go b/vendor/github.com/go-pg/pg/v10/types/append_ident.go deleted file mode 100644 index 60b9d6784..000000000 --- a/vendor/github.com/go-pg/pg/v10/types/append_ident.go +++ /dev/null @@ -1,46 +0,0 @@ -package types - -import "github.com/go-pg/pg/v10/internal" - -func AppendIdent(b []byte, field string, flags int) []byte { - return appendIdent(b, internal.StringToBytes(field), flags) -} - -func AppendIdentBytes(b []byte, field []byte, flags int) []byte { - return appendIdent(b, field, flags) -} - -func appendIdent(b, src []byte, flags int) []byte { - var quoted bool -loop: - for _, c := range src { - switch c { - case '*': - if !quoted { - b = append(b, '*') - continue loop - } - case '.': - if quoted && hasFlag(flags, quoteFlag) { - b = append(b, '"') - quoted = false - } - b = append(b, '.') - continue loop - } - - if !quoted && hasFlag(flags, quoteFlag) { - b = append(b, '"') - quoted = true - } - if c == '"' { - b = append(b, '"', '"') - } else { - b = append(b, c) - } - } - if quoted && hasFlag(flags, quoteFlag) { - b = append(b, '"') - } - return b -} diff --git a/vendor/github.com/go-pg/pg/v10/types/append_jsonb.go b/vendor/github.com/go-pg/pg/v10/types/append_jsonb.go deleted file mode 100644 index ffe221825..000000000 --- a/vendor/github.com/go-pg/pg/v10/types/append_jsonb.go +++ /dev/null @@ -1,49 +0,0 @@ -package types - -import "github.com/go-pg/pg/v10/internal/parser" - -func AppendJSONB(b, jsonb []byte, flags int) []byte { - if hasFlag(flags, arrayFlag) { - b = append(b, '"') - } else if hasFlag(flags, quoteFlag) { - b = append(b, '\'') - } - - p := parser.New(jsonb) - for p.Valid() { - c := p.Read() - switch c { - case '"': - if hasFlag(flags, arrayFlag) { - b = append(b, '\\') - } - b = append(b, '"') - case '\'': - if hasFlag(flags, quoteFlag) { - b = append(b, '\'') - } - b = append(b, '\'') - case '\000': - continue - case '\\': - if p.SkipBytes([]byte("u0000")) { - b = append(b, "\\\\u0000"...) - } else { - b = append(b, '\\') - if p.Valid() { - b = append(b, p.Read()) - } - } - default: - b = append(b, c) - } - } - - if hasFlag(flags, arrayFlag) { - b = append(b, '"') - } else if hasFlag(flags, quoteFlag) { - b = append(b, '\'') - } - - return b -} diff --git a/vendor/github.com/go-pg/pg/v10/types/append_value.go b/vendor/github.com/go-pg/pg/v10/types/append_value.go deleted file mode 100644 index f12fc564f..000000000 --- a/vendor/github.com/go-pg/pg/v10/types/append_value.go +++ /dev/null @@ -1,248 +0,0 @@ -package types - -import ( - "database/sql/driver" - "fmt" - "net" - "reflect" - "strconv" - "sync" - "time" - - "github.com/vmihailenco/bufpool" - - "github.com/go-pg/pg/v10/internal" - "github.com/go-pg/pg/v10/pgjson" -) - -var ( - driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() - appenderType = reflect.TypeOf((*ValueAppender)(nil)).Elem() -) - -type AppenderFunc func([]byte, reflect.Value, int) []byte - -var appenders []AppenderFunc - -//nolint -func init() { - appenders = []AppenderFunc{ - reflect.Bool: appendBoolValue, - reflect.Int: appendIntValue, - reflect.Int8: appendIntValue, - reflect.Int16: appendIntValue, - reflect.Int32: appendIntValue, - reflect.Int64: appendIntValue, - reflect.Uint: appendUintValue, - reflect.Uint8: appendUintValue, - reflect.Uint16: appendUintValue, - reflect.Uint32: appendUintValue, - reflect.Uint64: appendUintValue, - reflect.Uintptr: nil, - reflect.Float32: appendFloat32Value, - reflect.Float64: appendFloat64Value, - reflect.Complex64: nil, - reflect.Complex128: nil, - reflect.Array: appendJSONValue, - reflect.Chan: nil, - reflect.Func: nil, - reflect.Interface: appendIfaceValue, - reflect.Map: appendJSONValue, - reflect.Ptr: nil, - reflect.Slice: appendJSONValue, - reflect.String: appendStringValue, - reflect.Struct: appendStructValue, - reflect.UnsafePointer: nil, - } -} - -var appendersMap sync.Map - -// RegisterAppender registers an appender func for the value type. -// Expecting to be used only during initialization, it panics -// if there is already a registered appender for the given type. -func RegisterAppender(value interface{}, fn AppenderFunc) { - registerAppender(reflect.TypeOf(value), fn) -} - -func registerAppender(typ reflect.Type, fn AppenderFunc) { - _, loaded := appendersMap.LoadOrStore(typ, fn) - if loaded { - err := fmt.Errorf("pg: appender for the type=%s is already registered", - typ.String()) - panic(err) - } -} - -func Appender(typ reflect.Type) AppenderFunc { - if v, ok := appendersMap.Load(typ); ok { - return v.(AppenderFunc) - } - fn := appender(typ, false) - _, _ = appendersMap.LoadOrStore(typ, fn) - return fn -} - -func appender(typ reflect.Type, pgArray bool) AppenderFunc { - switch typ { - case timeType: - return appendTimeValue - case ipType: - return appendIPValue - case ipNetType: - return appendIPNetValue - case jsonRawMessageType: - return appendJSONRawMessageValue - } - - if typ.Implements(appenderType) { - return appendAppenderValue - } - if typ.Implements(driverValuerType) { - return appendDriverValuerValue - } - - kind := typ.Kind() - switch kind { - case reflect.Ptr: - return ptrAppenderFunc(typ) - case reflect.Slice: - if typ.Elem().Kind() == reflect.Uint8 { - return appendBytesValue - } - if pgArray { - return ArrayAppender(typ) - } - case reflect.Array: - if typ.Elem().Kind() == reflect.Uint8 { - return appendArrayBytesValue - } - } - return appenders[kind] -} - -func ptrAppenderFunc(typ reflect.Type) AppenderFunc { - appender := Appender(typ.Elem()) - return func(b []byte, v reflect.Value, flags int) []byte { - if v.IsNil() { - return AppendNull(b, flags) - } - return appender(b, v.Elem(), flags) - } -} - -func appendValue(b []byte, v reflect.Value, flags int) []byte { - if v.Kind() == reflect.Ptr && v.IsNil() { - return AppendNull(b, flags) - } - appender := Appender(v.Type()) - return appender(b, v, flags) -} - -func appendIfaceValue(b []byte, v reflect.Value, flags int) []byte { - return Append(b, v.Interface(), flags) -} - -func appendBoolValue(b []byte, v reflect.Value, _ int) []byte { - return appendBool(b, v.Bool()) -} - -func appendIntValue(b []byte, v reflect.Value, _ int) []byte { - return strconv.AppendInt(b, v.Int(), 10) -} - -func appendUintValue(b []byte, v reflect.Value, _ int) []byte { - return strconv.AppendUint(b, v.Uint(), 10) -} - -func appendFloat32Value(b []byte, v reflect.Value, flags int) []byte { - return appendFloat(b, v.Float(), flags, 32) -} - -func appendFloat64Value(b []byte, v reflect.Value, flags int) []byte { - return appendFloat(b, v.Float(), flags, 64) -} - -func appendBytesValue(b []byte, v reflect.Value, flags int) []byte { - return AppendBytes(b, v.Bytes(), flags) -} - -func appendArrayBytesValue(b []byte, v reflect.Value, flags int) []byte { - if v.CanAddr() { - return AppendBytes(b, v.Slice(0, v.Len()).Bytes(), flags) - } - - buf := bufpool.Get(v.Len()) - - tmp := buf.Bytes() - reflect.Copy(reflect.ValueOf(tmp), v) - b = AppendBytes(b, tmp, flags) - - bufpool.Put(buf) - - return b -} - -func appendStringValue(b []byte, v reflect.Value, flags int) []byte { - return AppendString(b, v.String(), flags) -} - -func appendStructValue(b []byte, v reflect.Value, flags int) []byte { - if v.Type() == timeType { - return appendTimeValue(b, v, flags) - } - return appendJSONValue(b, v, flags) -} - -var jsonPool bufpool.Pool - -func appendJSONValue(b []byte, v reflect.Value, flags int) []byte { - buf := jsonPool.Get() - defer jsonPool.Put(buf) - - if err := pgjson.NewEncoder(buf).Encode(v.Interface()); err != nil { - return AppendError(b, err) - } - - bb := buf.Bytes() - if len(bb) > 0 && bb[len(bb)-1] == '\n' { - bb = bb[:len(bb)-1] - } - - return AppendJSONB(b, bb, flags) -} - -func appendTimeValue(b []byte, v reflect.Value, flags int) []byte { - tm := v.Interface().(time.Time) - return AppendTime(b, tm, flags) -} - -func appendIPValue(b []byte, v reflect.Value, flags int) []byte { - ip := v.Interface().(net.IP) - return AppendString(b, ip.String(), flags) -} - -func appendIPNetValue(b []byte, v reflect.Value, flags int) []byte { - ipnet := v.Interface().(net.IPNet) - return AppendString(b, ipnet.String(), flags) -} - -func appendJSONRawMessageValue(b []byte, v reflect.Value, flags int) []byte { - return AppendString(b, internal.BytesToString(v.Bytes()), flags) -} - -func appendAppenderValue(b []byte, v reflect.Value, flags int) []byte { - return appendAppender(b, v.Interface().(ValueAppender), flags) -} - -func appendDriverValuerValue(b []byte, v reflect.Value, flags int) []byte { - return appendDriverValuer(b, v.Interface().(driver.Valuer), flags) -} - -func appendDriverValuer(b []byte, v driver.Valuer, flags int) []byte { - value, err := v.Value() - if err != nil { - return AppendError(b, err) - } - return Append(b, value, flags) -} diff --git a/vendor/github.com/go-pg/pg/v10/types/array.go b/vendor/github.com/go-pg/pg/v10/types/array.go deleted file mode 100644 index fb70c1f50..000000000 --- a/vendor/github.com/go-pg/pg/v10/types/array.go +++ /dev/null @@ -1,58 +0,0 @@ -package types - -import ( - "fmt" - "reflect" -) - -type Array struct { - v reflect.Value - - append AppenderFunc - scan ScannerFunc -} - -var ( - _ ValueAppender = (*Array)(nil) - _ ValueScanner = (*Array)(nil) -) - -func NewArray(vi interface{}) *Array { - v := reflect.ValueOf(vi) - if !v.IsValid() { - panic(fmt.Errorf("pg: Array(nil)")) - } - - return &Array{ - v: v, - - append: ArrayAppender(v.Type()), - scan: ArrayScanner(v.Type()), - } -} - -func (a *Array) AppendValue(b []byte, flags int) ([]byte, error) { - if a.append == nil { - panic(fmt.Errorf("pg: Array(unsupported %s)", a.v.Type())) - } - return a.append(b, a.v, flags), nil -} - -func (a *Array) ScanValue(rd Reader, n int) error { - if a.scan == nil { - return fmt.Errorf("pg: Array(unsupported %s)", a.v.Type()) - } - - if a.v.Kind() != reflect.Ptr { - return fmt.Errorf("pg: Array(non-pointer %s)", a.v.Type()) - } - - return a.scan(a.v.Elem(), rd, n) -} - -func (a *Array) Value() interface{} { - if a.v.IsValid() { - return a.v.Interface() - } - return nil -} diff --git a/vendor/github.com/go-pg/pg/v10/types/array_append.go b/vendor/github.com/go-pg/pg/v10/types/array_append.go deleted file mode 100644 index a4132eb61..000000000 --- a/vendor/github.com/go-pg/pg/v10/types/array_append.go +++ /dev/null @@ -1,236 +0,0 @@ -package types - -import ( - "reflect" - "strconv" - "sync" -) - -var ( - stringType = reflect.TypeOf((*string)(nil)).Elem() - sliceStringType = reflect.TypeOf([]string(nil)) - - intType = reflect.TypeOf((*int)(nil)).Elem() - sliceIntType = reflect.TypeOf([]int(nil)) - - int64Type = reflect.TypeOf((*int64)(nil)).Elem() - sliceInt64Type = reflect.TypeOf([]int64(nil)) - - float64Type = reflect.TypeOf((*float64)(nil)).Elem() - sliceFloat64Type = reflect.TypeOf([]float64(nil)) -) - -var arrayAppendersMap sync.Map - -func ArrayAppender(typ reflect.Type) AppenderFunc { - if v, ok := arrayAppendersMap.Load(typ); ok { - return v.(AppenderFunc) - } - fn := arrayAppender(typ) - arrayAppendersMap.Store(typ, fn) - return fn -} - -func arrayAppender(typ reflect.Type) AppenderFunc { - kind := typ.Kind() - if kind == reflect.Ptr { - typ = typ.Elem() - kind = typ.Kind() - } - - switch kind { - case reflect.Slice, reflect.Array: - // ok: - default: - return nil - } - - elemType := typ.Elem() - - if kind == reflect.Slice { - switch elemType { - case stringType: - return appendSliceStringValue - case intType: - return appendSliceIntValue - case int64Type: - return appendSliceInt64Value - case float64Type: - return appendSliceFloat64Value - } - } - - appendElem := appender(elemType, true) - return func(b []byte, v reflect.Value, flags int) []byte { - flags |= arrayFlag - - kind := v.Kind() - switch kind { - case reflect.Ptr, reflect.Slice: - if v.IsNil() { - return AppendNull(b, flags) - } - } - - if kind == reflect.Ptr { - v = v.Elem() - } - - quote := shouldQuoteArray(flags) - if quote { - b = append(b, '\'') - } - - flags |= subArrayFlag - - b = append(b, '{') - for i := 0; i < v.Len(); i++ { - elem := v.Index(i) - b = appendElem(b, elem, flags) - b = append(b, ',') - } - if v.Len() > 0 { - b[len(b)-1] = '}' // Replace trailing comma. - } else { - b = append(b, '}') - } - - if quote { - b = append(b, '\'') - } - - return b - } -} - -func appendSliceStringValue(b []byte, v reflect.Value, flags int) []byte { - ss := v.Convert(sliceStringType).Interface().([]string) - return appendSliceString(b, ss, flags) -} - -func appendSliceString(b []byte, ss []string, flags int) []byte { - if ss == nil { - return AppendNull(b, flags) - } - - quote := shouldQuoteArray(flags) - if quote { - b = append(b, '\'') - } - - b = append(b, '{') - for _, s := range ss { - b = appendString2(b, s, flags) - b = append(b, ',') - } - if len(ss) > 0 { - b[len(b)-1] = '}' // Replace trailing comma. - } else { - b = append(b, '}') - } - - if quote { - b = append(b, '\'') - } - - return b -} - -func appendSliceIntValue(b []byte, v reflect.Value, flags int) []byte { - ints := v.Convert(sliceIntType).Interface().([]int) - return appendSliceInt(b, ints, flags) -} - -func appendSliceInt(b []byte, ints []int, flags int) []byte { - if ints == nil { - return AppendNull(b, flags) - } - - quote := shouldQuoteArray(flags) - if quote { - b = append(b, '\'') - } - - b = append(b, '{') - for _, n := range ints { - b = strconv.AppendInt(b, int64(n), 10) - b = append(b, ',') - } - if len(ints) > 0 { - b[len(b)-1] = '}' // Replace trailing comma. - } else { - b = append(b, '}') - } - - if quote { - b = append(b, '\'') - } - - return b -} - -func appendSliceInt64Value(b []byte, v reflect.Value, flags int) []byte { - ints := v.Convert(sliceInt64Type).Interface().([]int64) - return appendSliceInt64(b, ints, flags) -} - -func appendSliceInt64(b []byte, ints []int64, flags int) []byte { - if ints == nil { - return AppendNull(b, flags) - } - - quote := shouldQuoteArray(flags) - if quote { - b = append(b, '\'') - } - - b = append(b, '{') - for _, n := range ints { - b = strconv.AppendInt(b, n, 10) - b = append(b, ',') - } - if len(ints) > 0 { - b[len(b)-1] = '}' // Replace trailing comma. - } else { - b = append(b, '}') - } - - if quote { - b = append(b, '\'') - } - - return b -} - -func appendSliceFloat64Value(b []byte, v reflect.Value, flags int) []byte { - floats := v.Convert(sliceFloat64Type).Interface().([]float64) - return appendSliceFloat64(b, floats, flags) -} - -func appendSliceFloat64(b []byte, floats []float64, flags int) []byte { - if floats == nil { - return AppendNull(b, flags) - } - - quote := shouldQuoteArray(flags) - if quote { - b = append(b, '\'') - } - - b = append(b, '{') - for _, n := range floats { - b = appendFloat2(b, n, flags) - b = append(b, ',') - } - if len(floats) > 0 { - b[len(b)-1] = '}' // Replace trailing comma. - } else { - b = append(b, '}') - } - - if quote { - b = append(b, '\'') - } - - return b -} diff --git a/vendor/github.com/go-pg/pg/v10/types/array_parser.go b/vendor/github.com/go-pg/pg/v10/types/array_parser.go deleted file mode 100644 index 0870a6568..000000000 --- a/vendor/github.com/go-pg/pg/v10/types/array_parser.go +++ /dev/null @@ -1,170 +0,0 @@ -package types - -import ( - "bufio" - "bytes" - "errors" - "fmt" - "io" - - "github.com/go-pg/pg/v10/internal/parser" -) - -var errEndOfArray = errors.New("pg: end of array") - -type arrayParser struct { - p parser.StreamingParser - - stickyErr error - buf []byte -} - -func newArrayParserErr(err error) *arrayParser { - return &arrayParser{ - stickyErr: err, - buf: make([]byte, 32), - } -} - -func newArrayParser(rd Reader) *arrayParser { - p := parser.NewStreamingParser(rd) - err := p.SkipByte('{') - if err != nil { - return newArrayParserErr(err) - } - return &arrayParser{ - p: p, - } -} - -func (p *arrayParser) NextElem() ([]byte, error) { - if p.stickyErr != nil { - return nil, p.stickyErr - } - - c, err := p.p.ReadByte() - if err != nil { - if err == io.EOF { - return nil, errEndOfArray - } - return nil, err - } - - switch c { - case '"': - b, err := p.p.ReadSubstring(p.buf[:0]) - if err != nil { - return nil, err - } - p.buf = b - - err = p.readCommaBrace() - if err != nil { - return nil, err - } - - return b, nil - case '{': - b, err := p.readSubArray(p.buf[:0]) - if err != nil { - return nil, err - } - p.buf = b - - err = p.readCommaBrace() - if err != nil { - return nil, err - } - - return b, nil - case '}': - return nil, errEndOfArray - default: - err = p.p.UnreadByte() - if err != nil { - return nil, err - } - - b, err := p.readSimple(p.buf[:0]) - if err != nil { - return nil, err - } - p.buf = b - - if bytes.Equal(b, []byte("NULL")) { - return nil, nil - } - return b, nil - } -} - -func (p *arrayParser) readSimple(b []byte) ([]byte, error) { - for { - tmp, err := p.p.ReadSlice(',') - if err == nil { - b = append(b, tmp...) - b = b[:len(b)-1] - break - } - b = append(b, tmp...) - if err == bufio.ErrBufferFull { - continue - } - if err == io.EOF { - if b[len(b)-1] == '}' { - b = b[:len(b)-1] - break - } - } - return nil, err - } - return b, nil -} - -func (p *arrayParser) readSubArray(b []byte) ([]byte, error) { - b = append(b, '{') - for { - c, err := p.p.ReadByte() - if err != nil { - return nil, err - } - - if c == '}' { - b = append(b, '}') - return b, nil - } - - if c == '"' { - b = append(b, '"') - for { - tmp, err := p.p.ReadSlice('"') - b = append(b, tmp...) - if err != nil { - if err == bufio.ErrBufferFull { - continue - } - return nil, err - } - if len(b) > 1 && b[len(b)-2] != '\\' { - break - } - } - continue - } - - b = append(b, c) - } -} - -func (p *arrayParser) readCommaBrace() error { - c, err := p.p.ReadByte() - if err != nil { - return err - } - switch c { - case ',', '}': - return nil - default: - return fmt.Errorf("pg: got %q, wanted ',' or '}'", c) - } -} diff --git a/vendor/github.com/go-pg/pg/v10/types/array_scan.go b/vendor/github.com/go-pg/pg/v10/types/array_scan.go deleted file mode 100644 index dbccafc06..000000000 --- a/vendor/github.com/go-pg/pg/v10/types/array_scan.go +++ /dev/null @@ -1,334 +0,0 @@ -package types - -import ( - "fmt" - "reflect" - - "github.com/go-pg/pg/v10/internal" - "github.com/go-pg/pg/v10/internal/pool" -) - -var arrayValueScannerType = reflect.TypeOf((*ArrayValueScanner)(nil)).Elem() - -type ArrayValueScanner interface { - BeforeScanArrayValue(rd Reader, n int) error - ScanArrayValue(rd Reader, n int) error - AfterScanArrayValue() error -} - -func ArrayScanner(typ reflect.Type) ScannerFunc { - if typ.Implements(arrayValueScannerType) { - return scanArrayValueScannerValue - } - - kind := typ.Kind() - if kind == reflect.Ptr { - typ = typ.Elem() - kind = typ.Kind() - } - - switch kind { - case reflect.Slice, reflect.Array: - // ok: - default: - return nil - } - - elemType := typ.Elem() - - if kind == reflect.Slice { - switch elemType { - case stringType: - return scanStringArrayValue - case intType: - return scanIntArrayValue - case int64Type: - return scanInt64ArrayValue - case float64Type: - return scanFloat64ArrayValue - } - } - - scanElem := scanner(elemType, true) - return func(v reflect.Value, rd Reader, n int) error { - v = reflect.Indirect(v) - if !v.CanSet() { - return fmt.Errorf("pg: Scan(non-settable %s)", v.Type()) - } - - kind := v.Kind() - - if n == -1 { - if kind != reflect.Slice || !v.IsNil() { - v.Set(reflect.Zero(v.Type())) - } - return nil - } - - if kind == reflect.Slice { - if v.IsNil() { - v.Set(reflect.MakeSlice(v.Type(), 0, 0)) - } else if v.Len() > 0 { - v.Set(v.Slice(0, 0)) - } - } - - p := newArrayParser(rd) - nextValue := internal.MakeSliceNextElemFunc(v) - var elemRd *pool.BytesReader - - for { - elem, err := p.NextElem() - if err != nil { - if err == errEndOfArray { - break - } - return err - } - - if elemRd == nil { - elemRd = pool.NewBytesReader(elem) - } else { - elemRd.Reset(elem) - } - - var elemN int - if elem == nil { - elemN = -1 - } else { - elemN = len(elem) - } - - elemValue := nextValue() - err = scanElem(elemValue, elemRd, elemN) - if err != nil { - return err - } - } - - return nil - } -} - -func scanStringArrayValue(v reflect.Value, rd Reader, n int) error { - v = reflect.Indirect(v) - if !v.CanSet() { - return fmt.Errorf("pg: Scan(non-settable %s)", v.Type()) - } - - strings, err := scanStringArray(rd, n) - if err != nil { - return err - } - - v.Set(reflect.ValueOf(strings)) - return nil -} - -func scanStringArray(rd Reader, n int) ([]string, error) { - if n == -1 { - return nil, nil - } - - p := newArrayParser(rd) - slice := make([]string, 0) - for { - elem, err := p.NextElem() - if err != nil { - if err == errEndOfArray { - break - } - return nil, err - } - - slice = append(slice, string(elem)) - } - - return slice, nil -} - -func scanIntArrayValue(v reflect.Value, rd Reader, n int) error { - v = reflect.Indirect(v) - if !v.CanSet() { - return fmt.Errorf("pg: Scan(non-settable %s)", v.Type()) - } - - slice, err := decodeSliceInt(rd, n) - if err != nil { - return err - } - - v.Set(reflect.ValueOf(slice)) - return nil -} - -func decodeSliceInt(rd Reader, n int) ([]int, error) { - if n == -1 { - return nil, nil - } - - p := newArrayParser(rd) - slice := make([]int, 0) - for { - elem, err := p.NextElem() - if err != nil { - if err == errEndOfArray { - break - } - return nil, err - } - - if elem == nil { - slice = append(slice, 0) - continue - } - - n, err := internal.Atoi(elem) - if err != nil { - return nil, err - } - - slice = append(slice, n) - } - - return slice, nil -} - -func scanInt64ArrayValue(v reflect.Value, rd Reader, n int) error { - v = reflect.Indirect(v) - if !v.CanSet() { - return fmt.Errorf("pg: Scan(non-settable %s)", v.Type()) - } - - slice, err := scanInt64Array(rd, n) - if err != nil { - return err - } - - v.Set(reflect.ValueOf(slice)) - return nil -} - -func scanInt64Array(rd Reader, n int) ([]int64, error) { - if n == -1 { - return nil, nil - } - - p := newArrayParser(rd) - slice := make([]int64, 0) - for { - elem, err := p.NextElem() - if err != nil { - if err == errEndOfArray { - break - } - return nil, err - } - - if elem == nil { - slice = append(slice, 0) - continue - } - - n, err := internal.ParseInt(elem, 10, 64) - if err != nil { - return nil, err - } - - slice = append(slice, n) - } - - return slice, nil -} - -func scanFloat64ArrayValue(v reflect.Value, rd Reader, n int) error { - v = reflect.Indirect(v) - if !v.CanSet() { - return fmt.Errorf("pg: Scan(non-settable %s)", v.Type()) - } - - slice, err := scanFloat64Array(rd, n) - if err != nil { - return err - } - - v.Set(reflect.ValueOf(slice)) - return nil -} - -func scanFloat64Array(rd Reader, n int) ([]float64, error) { - if n == -1 { - return nil, nil - } - - p := newArrayParser(rd) - slice := make([]float64, 0) - for { - elem, err := p.NextElem() - if err != nil { - if err == errEndOfArray { - break - } - return nil, err - } - - if elem == nil { - slice = append(slice, 0) - continue - } - - n, err := internal.ParseFloat(elem, 64) - if err != nil { - return nil, err - } - - slice = append(slice, n) - } - - return slice, nil -} - -func scanArrayValueScannerValue(v reflect.Value, rd Reader, n int) error { - if n == -1 { - return nil - } - - scanner := v.Addr().Interface().(ArrayValueScanner) - - err := scanner.BeforeScanArrayValue(rd, n) - if err != nil { - return err - } - - p := newArrayParser(rd) - var elemRd *pool.BytesReader - for { - elem, err := p.NextElem() - if err != nil { - if err == errEndOfArray { - break - } - return err - } - - if elemRd == nil { - elemRd = pool.NewBytesReader(elem) - } else { - elemRd.Reset(elem) - } - - var elemN int - if elem == nil { - elemN = -1 - } else { - elemN = len(elem) - } - - err = scanner.ScanArrayValue(elemRd, elemN) - if err != nil { - return err - } - } - - return scanner.AfterScanArrayValue() -} diff --git a/vendor/github.com/go-pg/pg/v10/types/column.go b/vendor/github.com/go-pg/pg/v10/types/column.go deleted file mode 100644 index e3470f3eb..000000000 --- a/vendor/github.com/go-pg/pg/v10/types/column.go +++ /dev/null @@ -1,113 +0,0 @@ -package types - -import ( - "encoding/json" - - "github.com/go-pg/pg/v10/internal/pool" - "github.com/go-pg/pg/v10/pgjson" -) - -const ( - pgBool = 16 - - pgInt2 = 21 - pgInt4 = 23 - pgInt8 = 20 - - pgFloat4 = 700 - pgFloat8 = 701 - - pgText = 25 - pgVarchar = 1043 - pgBytea = 17 - pgJSON = 114 - pgJSONB = 3802 - - pgTimestamp = 1114 - pgTimestamptz = 1184 - - // pgInt2Array = 1005 - pgInt32Array = 1007 - pgInt8Array = 1016 - pgFloat8Array = 1022 - pgStringArray = 1009 - - pgUUID = 2950 -) - -type ColumnInfo = pool.ColumnInfo - -type RawValue struct { - Type int32 - Value string -} - -func (v RawValue) AppendValue(b []byte, flags int) ([]byte, error) { - return AppendString(b, v.Value, flags), nil -} - -func (v RawValue) MarshalJSON() ([]byte, error) { - return pgjson.Marshal(v.Value) -} - -func ReadColumnValue(col ColumnInfo, rd Reader, n int) (interface{}, error) { - switch col.DataType { - case pgBool: - return ScanBool(rd, n) - - case pgInt2: - n, err := scanInt64(rd, n, 16) - if err != nil { - return nil, err - } - return int16(n), nil - case pgInt4: - n, err := scanInt64(rd, n, 32) - if err != nil { - return nil, err - } - return int32(n), nil - case pgInt8: - return ScanInt64(rd, n) - - case pgFloat4: - return ScanFloat32(rd, n) - case pgFloat8: - return ScanFloat64(rd, n) - - case pgBytea: - return ScanBytes(rd, n) - case pgText, pgVarchar, pgUUID: - return ScanString(rd, n) - case pgJSON, pgJSONB: - s, err := ScanString(rd, n) - if err != nil { - return nil, err - } - return json.RawMessage(s), nil - - case pgTimestamp: - return ScanTime(rd, n) - case pgTimestamptz: - return ScanTime(rd, n) - - case pgInt32Array: - return scanInt64Array(rd, n) - case pgInt8Array: - return scanInt64Array(rd, n) - case pgFloat8Array: - return scanFloat64Array(rd, n) - case pgStringArray: - return scanStringArray(rd, n) - - default: - s, err := ScanString(rd, n) - if err != nil { - return nil, err - } - return RawValue{ - Type: col.DataType, - Value: s, - }, nil - } -} diff --git a/vendor/github.com/go-pg/pg/v10/types/doc.go b/vendor/github.com/go-pg/pg/v10/types/doc.go deleted file mode 100644 index 890ef3c08..000000000 --- a/vendor/github.com/go-pg/pg/v10/types/doc.go +++ /dev/null @@ -1,4 +0,0 @@ -/* -The API in this package is not stable and may change without any notice. -*/ -package types diff --git a/vendor/github.com/go-pg/pg/v10/types/flags.go b/vendor/github.com/go-pg/pg/v10/types/flags.go deleted file mode 100644 index 10e415f14..000000000 --- a/vendor/github.com/go-pg/pg/v10/types/flags.go +++ /dev/null @@ -1,25 +0,0 @@ -package types - -import "reflect" - -const ( - quoteFlag = 1 << iota - arrayFlag - subArrayFlag -) - -func hasFlag(flags, flag int) bool { - return flags&flag == flag -} - -func shouldQuoteArray(flags int) bool { - return hasFlag(flags, quoteFlag) && !hasFlag(flags, subArrayFlag) -} - -func nilable(v reflect.Value) bool { - switch v.Kind() { - case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: - return true - } - return false -} diff --git a/vendor/github.com/go-pg/pg/v10/types/hex.go b/vendor/github.com/go-pg/pg/v10/types/hex.go deleted file mode 100644 index 8ae6469b9..000000000 --- a/vendor/github.com/go-pg/pg/v10/types/hex.go +++ /dev/null @@ -1,81 +0,0 @@ -package types - -import ( - "bytes" - "encoding/hex" - "fmt" - "io" - - fasthex "github.com/tmthrgd/go-hex" -) - -type HexEncoder struct { - b []byte - flags int - written bool -} - -func NewHexEncoder(b []byte, flags int) *HexEncoder { - return &HexEncoder{ - b: b, - flags: flags, - } -} - -func (enc *HexEncoder) Bytes() []byte { - return enc.b -} - -func (enc *HexEncoder) Write(b []byte) (int, error) { - if !enc.written { - if hasFlag(enc.flags, arrayFlag) { - enc.b = append(enc.b, `"\`...) - } else if hasFlag(enc.flags, quoteFlag) { - enc.b = append(enc.b, '\'') - } - enc.b = append(enc.b, `\x`...) - enc.written = true - } - - i := len(enc.b) - enc.b = append(enc.b, make([]byte, fasthex.EncodedLen(len(b)))...) - fasthex.Encode(enc.b[i:], b) - - return len(b), nil -} - -func (enc *HexEncoder) Close() error { - if enc.written { - if hasFlag(enc.flags, arrayFlag) { - enc.b = append(enc.b, '"') - } else if hasFlag(enc.flags, quoteFlag) { - enc.b = append(enc.b, '\'') - } - } else { - enc.b = AppendNull(enc.b, enc.flags) - } - return nil -} - -//------------------------------------------------------------------------------ - -func NewHexDecoder(rd Reader, n int) (io.Reader, error) { - if n <= 0 { - var rd bytes.Reader - return &rd, nil - } - - if c, err := rd.ReadByte(); err != nil { - return nil, err - } else if c != '\\' { - return nil, fmt.Errorf("got %q, wanted %q", c, '\\') - } - - if c, err := rd.ReadByte(); err != nil { - return nil, err - } else if c != 'x' { - return nil, fmt.Errorf("got %q, wanted %q", c, 'x') - } - - return hex.NewDecoder(rd), nil -} diff --git a/vendor/github.com/go-pg/pg/v10/types/hstore.go b/vendor/github.com/go-pg/pg/v10/types/hstore.go deleted file mode 100644 index 58c214ac6..000000000 --- a/vendor/github.com/go-pg/pg/v10/types/hstore.go +++ /dev/null @@ -1,59 +0,0 @@ -package types - -import ( - "fmt" - "reflect" -) - -type Hstore struct { - v reflect.Value - - append AppenderFunc - scan ScannerFunc -} - -var ( - _ ValueAppender = (*Hstore)(nil) - _ ValueScanner = (*Hstore)(nil) -) - -func NewHstore(vi interface{}) *Hstore { - v := reflect.ValueOf(vi) - if !v.IsValid() { - panic(fmt.Errorf("pg.Hstore(nil)")) - } - - typ := v.Type() - if typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } - if typ.Kind() != reflect.Map { - panic(fmt.Errorf("pg.Hstore(unsupported %s)", typ)) - } - - return &Hstore{ - v: v, - - append: HstoreAppender(typ), - scan: HstoreScanner(typ), - } -} - -func (h *Hstore) Value() interface{} { - if h.v.IsValid() { - return h.v.Interface() - } - return nil -} - -func (h *Hstore) AppendValue(b []byte, flags int) ([]byte, error) { - return h.append(b, h.v, flags), nil -} - -func (h *Hstore) ScanValue(rd Reader, n int) error { - if h.v.Kind() != reflect.Ptr { - return fmt.Errorf("pg: Hstore(non-pointer %s)", h.v.Type()) - } - - return h.scan(h.v.Elem(), rd, n) -} diff --git a/vendor/github.com/go-pg/pg/v10/types/hstore_append.go b/vendor/github.com/go-pg/pg/v10/types/hstore_append.go deleted file mode 100644 index e27292afa..000000000 --- a/vendor/github.com/go-pg/pg/v10/types/hstore_append.go +++ /dev/null @@ -1,50 +0,0 @@ -package types - -import ( - "fmt" - "reflect" -) - -var mapStringStringType = reflect.TypeOf(map[string]string(nil)) - -func HstoreAppender(typ reflect.Type) AppenderFunc { - if typ.Key() == stringType && typ.Elem() == stringType { - return appendMapStringStringValue - } - - return func(b []byte, v reflect.Value, flags int) []byte { - err := fmt.Errorf("pg.Hstore(unsupported %s)", v.Type()) - return AppendError(b, err) - } -} - -func appendMapStringString(b []byte, m map[string]string, flags int) []byte { - if m == nil { - return AppendNull(b, flags) - } - - if hasFlag(flags, quoteFlag) { - b = append(b, '\'') - } - - for key, value := range m { - b = appendString2(b, key, flags) - b = append(b, '=', '>') - b = appendString2(b, value, flags) - b = append(b, ',') - } - if len(m) > 0 { - b = b[:len(b)-1] // Strip trailing comma. - } - - if hasFlag(flags, quoteFlag) { - b = append(b, '\'') - } - - return b -} - -func appendMapStringStringValue(b []byte, v reflect.Value, flags int) []byte { - m := v.Convert(mapStringStringType).Interface().(map[string]string) - return appendMapStringString(b, m, flags) -} diff --git a/vendor/github.com/go-pg/pg/v10/types/hstore_parser.go b/vendor/github.com/go-pg/pg/v10/types/hstore_parser.go deleted file mode 100644 index 79cd41eda..000000000 --- a/vendor/github.com/go-pg/pg/v10/types/hstore_parser.go +++ /dev/null @@ -1,65 +0,0 @@ -package types - -import ( - "errors" - "io" - - "github.com/go-pg/pg/v10/internal/parser" -) - -var errEndOfHstore = errors.New("pg: end of hstore") - -type hstoreParser struct { - p parser.StreamingParser -} - -func newHstoreParser(rd Reader) *hstoreParser { - return &hstoreParser{ - p: parser.NewStreamingParser(rd), - } -} - -func (p *hstoreParser) NextKey() ([]byte, error) { - err := p.p.SkipByte('"') - if err != nil { - if err == io.EOF { - return nil, errEndOfHstore - } - return nil, err - } - - key, err := p.p.ReadSubstring(nil) - if err != nil { - return nil, err - } - - err = p.p.SkipByte('=') - if err != nil { - return nil, err - } - err = p.p.SkipByte('>') - if err != nil { - return nil, err - } - - return key, nil -} - -func (p *hstoreParser) NextValue() ([]byte, error) { - err := p.p.SkipByte('"') - if err != nil { - return nil, err - } - - value, err := p.p.ReadSubstring(nil) - if err != nil { - return nil, err - } - - err = p.p.SkipByte(',') - if err == nil { - _ = p.p.SkipByte(' ') - } - - return value, nil -} diff --git a/vendor/github.com/go-pg/pg/v10/types/hstore_scan.go b/vendor/github.com/go-pg/pg/v10/types/hstore_scan.go deleted file mode 100644 index 2061c6163..000000000 --- a/vendor/github.com/go-pg/pg/v10/types/hstore_scan.go +++ /dev/null @@ -1,51 +0,0 @@ -package types - -import ( - "fmt" - "reflect" -) - -func HstoreScanner(typ reflect.Type) ScannerFunc { - if typ.Key() == stringType && typ.Elem() == stringType { - return scanMapStringStringValue - } - return func(v reflect.Value, rd Reader, n int) error { - return fmt.Errorf("pg.Hstore(unsupported %s)", v.Type()) - } -} - -func scanMapStringStringValue(v reflect.Value, rd Reader, n int) error { - m, err := scanMapStringString(rd, n) - if err != nil { - return err - } - - v.Set(reflect.ValueOf(m)) - return nil -} - -func scanMapStringString(rd Reader, n int) (map[string]string, error) { - if n == -1 { - return nil, nil - } - - p := newHstoreParser(rd) - m := make(map[string]string) - for { - key, err := p.NextKey() - if err != nil { - if err == errEndOfHstore { - break - } - return nil, err - } - - value, err := p.NextValue() - if err != nil { - return nil, err - } - - m[string(key)] = string(value) - } - return m, nil -} diff --git a/vendor/github.com/go-pg/pg/v10/types/in_op.go b/vendor/github.com/go-pg/pg/v10/types/in_op.go deleted file mode 100644 index 472b986d8..000000000 --- a/vendor/github.com/go-pg/pg/v10/types/in_op.go +++ /dev/null @@ -1,62 +0,0 @@ -package types - -import ( - "fmt" - "reflect" -) - -type inOp struct { - slice reflect.Value - stickyErr error -} - -var _ ValueAppender = (*inOp)(nil) - -func InMulti(values ...interface{}) ValueAppender { - return &inOp{ - slice: reflect.ValueOf(values), - } -} - -func In(slice interface{}) ValueAppender { - v := reflect.ValueOf(slice) - if v.Kind() != reflect.Slice { - return &inOp{ - stickyErr: fmt.Errorf("pg: In(non-slice %T)", slice), - } - } - - return &inOp{ - slice: v, - } -} - -func (in *inOp) AppendValue(b []byte, flags int) ([]byte, error) { - if in.stickyErr != nil { - return nil, in.stickyErr - } - return appendIn(b, in.slice, flags), nil -} - -func appendIn(b []byte, slice reflect.Value, flags int) []byte { - sliceLen := slice.Len() - for i := 0; i < sliceLen; i++ { - if i > 0 { - b = append(b, ',') - } - - elem := slice.Index(i) - if elem.Kind() == reflect.Interface { - elem = elem.Elem() - } - - if elem.Kind() == reflect.Slice { - b = append(b, '(') - b = appendIn(b, elem, flags) - b = append(b, ')') - } else { - b = appendValue(b, elem, flags) - } - } - return b -} diff --git a/vendor/github.com/go-pg/pg/v10/types/null_time.go b/vendor/github.com/go-pg/pg/v10/types/null_time.go deleted file mode 100644 index 3c3f1f79a..000000000 --- a/vendor/github.com/go-pg/pg/v10/types/null_time.go +++ /dev/null @@ -1,58 +0,0 @@ -package types - -import ( - "bytes" - "database/sql" - "encoding/json" - "time" -) - -var jsonNull = []byte("null") - -// NullTime is a time.Time wrapper that marshals zero time as JSON null and -// PostgreSQL NULL. -type NullTime struct { - time.Time -} - -var ( - _ json.Marshaler = (*NullTime)(nil) - _ json.Unmarshaler = (*NullTime)(nil) - _ sql.Scanner = (*NullTime)(nil) - _ ValueAppender = (*NullTime)(nil) -) - -func (tm NullTime) MarshalJSON() ([]byte, error) { - if tm.IsZero() { - return jsonNull, nil - } - return tm.Time.MarshalJSON() -} - -func (tm *NullTime) UnmarshalJSON(b []byte) error { - if bytes.Equal(b, jsonNull) { - tm.Time = time.Time{} - return nil - } - return tm.Time.UnmarshalJSON(b) -} - -func (tm NullTime) AppendValue(b []byte, flags int) ([]byte, error) { - if tm.IsZero() { - return AppendNull(b, flags), nil - } - return AppendTime(b, tm.Time, flags), nil -} - -func (tm *NullTime) Scan(b interface{}) error { - if b == nil { - tm.Time = time.Time{} - return nil - } - newtm, err := ParseTime(b.([]byte)) - if err != nil { - return err - } - tm.Time = newtm - return nil -} diff --git a/vendor/github.com/go-pg/pg/v10/types/scan.go b/vendor/github.com/go-pg/pg/v10/types/scan.go deleted file mode 100644 index 2e9c0cc85..000000000 --- a/vendor/github.com/go-pg/pg/v10/types/scan.go +++ /dev/null @@ -1,244 +0,0 @@ -package types - -import ( - "errors" - "fmt" - "reflect" - "time" - - "github.com/tmthrgd/go-hex" - - "github.com/go-pg/pg/v10/internal" -) - -func Scan(v interface{}, rd Reader, n int) error { - var err error - switch v := v.(type) { - case *string: - *v, err = ScanString(rd, n) - return err - case *[]byte: - *v, err = ScanBytes(rd, n) - return err - case *int: - *v, err = ScanInt(rd, n) - return err - case *int64: - *v, err = ScanInt64(rd, n) - return err - case *float32: - *v, err = ScanFloat32(rd, n) - return err - case *float64: - *v, err = ScanFloat64(rd, n) - return err - case *time.Time: - *v, err = ScanTime(rd, n) - return err - } - - vv := reflect.ValueOf(v) - if !vv.IsValid() { - return errors.New("pg: Scan(nil)") - } - - if vv.Kind() != reflect.Ptr { - return fmt.Errorf("pg: Scan(non-pointer %T)", v) - } - if vv.IsNil() { - return fmt.Errorf("pg: Scan(non-settable %T)", v) - } - - vv = vv.Elem() - if vv.Kind() == reflect.Interface { - if vv.IsNil() { - return errors.New("pg: Scan(nil)") - } - - vv = vv.Elem() - if vv.Kind() != reflect.Ptr { - return fmt.Errorf("pg: Decode(non-pointer %s)", vv.Type().String()) - } - } - - return ScanValue(vv, rd, n) -} - -func ScanString(rd Reader, n int) (string, error) { - if n <= 0 { - return "", nil - } - - b, err := rd.ReadFull() - if err != nil { - return "", err - } - - return internal.BytesToString(b), nil -} - -func ScanBytes(rd Reader, n int) ([]byte, error) { - if n == -1 { - return nil, nil - } - if n == 0 { - return []byte{}, nil - } - - b := make([]byte, hex.DecodedLen(n-2)) - if err := ReadBytes(rd, b); err != nil { - return nil, err - } - return b, nil -} - -func ReadBytes(rd Reader, b []byte) error { - tmp, err := rd.ReadFullTemp() - if err != nil { - return err - } - - if len(tmp) < 2 { - return fmt.Errorf("pg: can't parse bytea: %q", tmp) - } - - if tmp[0] != '\\' || tmp[1] != 'x' { - return fmt.Errorf("pg: can't parse bytea: %q", tmp) - } - tmp = tmp[2:] // Trim off "\\x". - - if len(b) != hex.DecodedLen(len(tmp)) { - return fmt.Errorf("pg: too small buf to decode hex") - } - - if _, err := hex.Decode(b, tmp); err != nil { - return err - } - - return nil -} - -func ScanInt(rd Reader, n int) (int, error) { - if n <= 0 { - return 0, nil - } - - tmp, err := rd.ReadFullTemp() - if err != nil { - return 0, err - } - - num, err := internal.Atoi(tmp) - if err != nil { - return 0, err - } - - return num, nil -} - -func ScanInt64(rd Reader, n int) (int64, error) { - return scanInt64(rd, n, 64) -} - -func scanInt64(rd Reader, n int, bitSize int) (int64, error) { - if n <= 0 { - return 0, nil - } - - tmp, err := rd.ReadFullTemp() - if err != nil { - return 0, err - } - - num, err := internal.ParseInt(tmp, 10, bitSize) - if err != nil { - return 0, err - } - - return num, nil -} - -func ScanUint64(rd Reader, n int) (uint64, error) { - if n <= 0 { - return 0, nil - } - - tmp, err := rd.ReadFullTemp() - if err != nil { - return 0, err - } - - // PostgreSQL does not natively support uint64 - only int64. - // Be nice and accept negative int64. - if len(tmp) > 0 && tmp[0] == '-' { - num, err := internal.ParseInt(tmp, 10, 64) - if err != nil { - return 0, err - } - return uint64(num), nil - } - - num, err := internal.ParseUint(tmp, 10, 64) - if err != nil { - return 0, err - } - - return num, nil -} - -func ScanFloat32(rd Reader, n int) (float32, error) { - if n <= 0 { - return 0, nil - } - - tmp, err := rd.ReadFullTemp() - if err != nil { - return 0, err - } - - num, err := internal.ParseFloat(tmp, 32) - if err != nil { - return 0, err - } - - return float32(num), nil -} - -func ScanFloat64(rd Reader, n int) (float64, error) { - if n <= 0 { - return 0, nil - } - - tmp, err := rd.ReadFullTemp() - if err != nil { - return 0, err - } - - num, err := internal.ParseFloat(tmp, 64) - if err != nil { - return 0, err - } - - return num, nil -} - -func ScanTime(rd Reader, n int) (time.Time, error) { - if n <= 0 { - return time.Time{}, nil - } - - tmp, err := rd.ReadFullTemp() - if err != nil { - return time.Time{}, err - } - - return ParseTime(tmp) -} - -func ScanBool(rd Reader, n int) (bool, error) { - tmp, err := rd.ReadFullTemp() - if err != nil { - return false, err - } - return len(tmp) == 1 && (tmp[0] == 't' || tmp[0] == '1'), nil -} diff --git a/vendor/github.com/go-pg/pg/v10/types/scan_value.go b/vendor/github.com/go-pg/pg/v10/types/scan_value.go deleted file mode 100644 index 9f5a7bb6e..000000000 --- a/vendor/github.com/go-pg/pg/v10/types/scan_value.go +++ /dev/null @@ -1,418 +0,0 @@ -package types - -import ( - "database/sql" - "encoding/json" - "errors" - "fmt" - "net" - "reflect" - "sync" - "time" - - "github.com/go-pg/pg/v10/internal" - "github.com/go-pg/pg/v10/pgjson" -) - -var ( - valueScannerType = reflect.TypeOf((*ValueScanner)(nil)).Elem() - sqlScannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() - timeType = reflect.TypeOf((*time.Time)(nil)).Elem() - ipType = reflect.TypeOf((*net.IP)(nil)).Elem() - ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() - jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() -) - -type ScannerFunc func(reflect.Value, Reader, int) error - -var valueScanners []ScannerFunc - -//nolint -func init() { - valueScanners = []ScannerFunc{ - reflect.Bool: scanBoolValue, - reflect.Int: scanInt64Value, - reflect.Int8: scanInt64Value, - reflect.Int16: scanInt64Value, - reflect.Int32: scanInt64Value, - reflect.Int64: scanInt64Value, - reflect.Uint: scanUint64Value, - reflect.Uint8: scanUint64Value, - reflect.Uint16: scanUint64Value, - reflect.Uint32: scanUint64Value, - reflect.Uint64: scanUint64Value, - reflect.Uintptr: nil, - reflect.Float32: scanFloat32Value, - reflect.Float64: scanFloat64Value, - reflect.Complex64: nil, - reflect.Complex128: nil, - reflect.Array: scanJSONValue, - reflect.Chan: nil, - reflect.Func: nil, - reflect.Interface: scanIfaceValue, - reflect.Map: scanJSONValue, - reflect.Ptr: nil, - reflect.Slice: scanJSONValue, - reflect.String: scanStringValue, - reflect.Struct: scanJSONValue, - reflect.UnsafePointer: nil, - } -} - -var scannersMap sync.Map - -// RegisterScanner registers an scanner func for the type. -// Expecting to be used only during initialization, it panics -// if there is already a registered scanner for the given type. -func RegisterScanner(value interface{}, fn ScannerFunc) { - registerScanner(reflect.TypeOf(value), fn) -} - -func registerScanner(typ reflect.Type, fn ScannerFunc) { - _, loaded := scannersMap.LoadOrStore(typ, fn) - if loaded { - err := fmt.Errorf("pg: scanner for the type=%s is already registered", - typ.String()) - panic(err) - } -} - -func Scanner(typ reflect.Type) ScannerFunc { - if v, ok := scannersMap.Load(typ); ok { - return v.(ScannerFunc) - } - fn := scanner(typ, false) - _, _ = scannersMap.LoadOrStore(typ, fn) - return fn -} - -func scanner(typ reflect.Type, pgArray bool) ScannerFunc { - switch typ { - case timeType: - return scanTimeValue - case ipType: - return scanIPValue - case ipNetType: - return scanIPNetValue - case jsonRawMessageType: - return scanJSONRawMessageValue - } - - if typ.Implements(valueScannerType) { - return scanValueScannerValue - } - if reflect.PtrTo(typ).Implements(valueScannerType) { - return scanValueScannerAddrValue - } - - if typ.Implements(sqlScannerType) { - return scanSQLScannerValue - } - if reflect.PtrTo(typ).Implements(sqlScannerType) { - return scanSQLScannerAddrValue - } - - kind := typ.Kind() - switch kind { - case reflect.Ptr: - return ptrScannerFunc(typ) - case reflect.Slice: - if typ.Elem().Kind() == reflect.Uint8 { - return scanBytesValue - } - if pgArray { - return ArrayScanner(typ) - } - case reflect.Array: - if typ.Elem().Kind() == reflect.Uint8 { - return scanArrayBytesValue - } - } - return valueScanners[kind] -} - -func ptrScannerFunc(typ reflect.Type) ScannerFunc { - scanner := Scanner(typ.Elem()) - return func(v reflect.Value, rd Reader, n int) error { - if scanner == nil { - return fmt.Errorf("pg: Scan(unsupported %s)", v.Type()) - } - - if n == -1 { - if v.IsNil() { - return nil - } - if !v.CanSet() { - return fmt.Errorf("pg: Scan(non-settable %s)", v.Type()) - } - v.Set(reflect.Zero(v.Type())) - return nil - } - - if v.IsNil() { - if !v.CanSet() { - return fmt.Errorf("pg: Scan(non-settable %s)", v.Type()) - } - v.Set(reflect.New(v.Type().Elem())) - } - - return scanner(v.Elem(), rd, n) - } -} - -func scanIfaceValue(v reflect.Value, rd Reader, n int) error { - if v.IsNil() { - return scanJSONValue(v, rd, n) - } - return ScanValue(v.Elem(), rd, n) -} - -func ScanValue(v reflect.Value, rd Reader, n int) error { - if !v.IsValid() { - return errors.New("pg: Scan(nil)") - } - - scanner := Scanner(v.Type()) - if scanner != nil { - return scanner(v, rd, n) - } - - if v.Kind() == reflect.Interface { - return errors.New("pg: Scan(nil)") - } - return fmt.Errorf("pg: Scan(unsupported %s)", v.Type()) -} - -func scanBoolValue(v reflect.Value, rd Reader, n int) error { - if n == -1 { - v.SetBool(false) - return nil - } - - flag, err := ScanBool(rd, n) - if err != nil { - return err - } - - v.SetBool(flag) - return nil -} - -func scanInt64Value(v reflect.Value, rd Reader, n int) error { - num, err := ScanInt64(rd, n) - if err != nil { - return err - } - - v.SetInt(num) - return nil -} - -func scanUint64Value(v reflect.Value, rd Reader, n int) error { - num, err := ScanUint64(rd, n) - if err != nil { - return err - } - - v.SetUint(num) - return nil -} - -func scanFloat32Value(v reflect.Value, rd Reader, n int) error { - num, err := ScanFloat32(rd, n) - if err != nil { - return err - } - - v.SetFloat(float64(num)) - return nil -} - -func scanFloat64Value(v reflect.Value, rd Reader, n int) error { - num, err := ScanFloat64(rd, n) - if err != nil { - return err - } - - v.SetFloat(num) - return nil -} - -func scanStringValue(v reflect.Value, rd Reader, n int) error { - s, err := ScanString(rd, n) - if err != nil { - return err - } - - v.SetString(s) - return nil -} - -func scanJSONValue(v reflect.Value, rd Reader, n int) error { - // Zero value so it works with SelectOrInsert. - // TODO: better handle slices - v.Set(reflect.New(v.Type()).Elem()) - - if n == -1 { - return nil - } - - dec := pgjson.NewDecoder(rd) - return dec.Decode(v.Addr().Interface()) -} - -func scanTimeValue(v reflect.Value, rd Reader, n int) error { - tm, err := ScanTime(rd, n) - if err != nil { - return err - } - - ptr := v.Addr().Interface().(*time.Time) - *ptr = tm - - return nil -} - -func scanIPValue(v reflect.Value, rd Reader, n int) error { - if n == -1 { - return nil - } - - tmp, err := rd.ReadFullTemp() - if err != nil { - return err - } - - ip := net.ParseIP(internal.BytesToString(tmp)) - if ip == nil { - return fmt.Errorf("pg: invalid ip=%q", tmp) - } - - ptr := v.Addr().Interface().(*net.IP) - *ptr = ip - - return nil -} - -var zeroIPNetValue = reflect.ValueOf(net.IPNet{}) - -func scanIPNetValue(v reflect.Value, rd Reader, n int) error { - if n == -1 { - v.Set(zeroIPNetValue) - return nil - } - - tmp, err := rd.ReadFullTemp() - if err != nil { - return err - } - - _, ipnet, err := net.ParseCIDR(internal.BytesToString(tmp)) - if err != nil { - return err - } - - ptr := v.Addr().Interface().(*net.IPNet) - *ptr = *ipnet - - return nil -} - -func scanJSONRawMessageValue(v reflect.Value, rd Reader, n int) error { - if n == -1 { - v.SetBytes(nil) - return nil - } - - b, err := rd.ReadFull() - if err != nil { - return err - } - - v.SetBytes(b) - return nil -} - -func scanBytesValue(v reflect.Value, rd Reader, n int) error { - if n == -1 { - v.SetBytes(nil) - return nil - } - - b, err := ScanBytes(rd, n) - if err != nil { - return err - } - - v.SetBytes(b) - return nil -} - -func scanArrayBytesValue(v reflect.Value, rd Reader, n int) error { - b := v.Slice(0, v.Len()).Bytes() - - if n == -1 { - for i := range b { - b[i] = 0 - } - return nil - } - - return ReadBytes(rd, b) -} - -func scanValueScannerValue(v reflect.Value, rd Reader, n int) error { - if n == -1 { - if v.IsNil() { - return nil - } - return v.Interface().(ValueScanner).ScanValue(rd, n) - } - - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) - } - - return v.Interface().(ValueScanner).ScanValue(rd, n) -} - -func scanValueScannerAddrValue(v reflect.Value, rd Reader, n int) error { - if !v.CanAddr() { - return fmt.Errorf("pg: Scan(non-settable %s)", v.Type()) - } - return v.Addr().Interface().(ValueScanner).ScanValue(rd, n) -} - -func scanSQLScannerValue(v reflect.Value, rd Reader, n int) error { - if n == -1 { - if nilable(v) && v.IsNil() { - return nil - } - return scanSQLScanner(v.Interface().(sql.Scanner), rd, n) - } - - if nilable(v) && v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) - } - - return scanSQLScanner(v.Interface().(sql.Scanner), rd, n) -} - -func scanSQLScannerAddrValue(v reflect.Value, rd Reader, n int) error { - if !v.CanAddr() { - return fmt.Errorf("pg: Scan(non-settable %s)", v.Type()) - } - return scanSQLScanner(v.Addr().Interface().(sql.Scanner), rd, n) -} - -func scanSQLScanner(scanner sql.Scanner, rd Reader, n int) error { - if n == -1 { - return scanner.Scan(nil) - } - - tmp, err := rd.ReadFullTemp() - if err != nil { - return err - } - return scanner.Scan(tmp) -} diff --git a/vendor/github.com/go-pg/pg/v10/types/types.go b/vendor/github.com/go-pg/pg/v10/types/types.go deleted file mode 100644 index 718ac2933..000000000 --- a/vendor/github.com/go-pg/pg/v10/types/types.go +++ /dev/null @@ -1,37 +0,0 @@ -package types - -import ( - "github.com/go-pg/pg/v10/internal/pool" -) - -type Reader = pool.Reader - -type ValueScanner interface { - ScanValue(rd Reader, n int) error -} - -type ValueAppender interface { - AppendValue(b []byte, flags int) ([]byte, error) -} - -//------------------------------------------------------------------------------ - -// Safe represents a safe SQL query. -type Safe string - -var _ ValueAppender = (*Safe)(nil) - -func (q Safe) AppendValue(b []byte, flags int) ([]byte, error) { - return append(b, q...), nil -} - -//------------------------------------------------------------------------------ - -// Ident represents a SQL identifier, e.g. table or column name. -type Ident string - -var _ ValueAppender = (*Ident)(nil) - -func (f Ident) AppendValue(b []byte, flags int) ([]byte, error) { - return AppendIdent(b, string(f), flags), nil -} diff --git a/vendor/github.com/go-pg/zerochecker/LICENSE b/vendor/github.com/go-pg/zerochecker/LICENSE deleted file mode 100644 index 7751509b8..000000000 --- a/vendor/github.com/go-pg/zerochecker/LICENSE +++ /dev/null @@ -1,24 +0,0 @@ -Copyright (c) 2013 github.com/go-pg/pg Authors. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/go-pg/zerochecker/go.mod b/vendor/github.com/go-pg/zerochecker/go.mod deleted file mode 100644 index f01e8a9f4..000000000 --- a/vendor/github.com/go-pg/zerochecker/go.mod +++ /dev/null @@ -1,3 +0,0 @@ -module github.com/go-pg/zerochecker - -go 1.13 diff --git a/vendor/github.com/jackc/chunkreader/v2/.travis.yml b/vendor/github.com/jackc/chunkreader/v2/.travis.yml new file mode 100644 index 000000000..e176228e8 --- /dev/null +++ b/vendor/github.com/jackc/chunkreader/v2/.travis.yml @@ -0,0 +1,9 @@ +language: go + +go: + - 1.x + - tip + +matrix: + allow_failures: + - go: tip diff --git a/vendor/github.com/jackc/chunkreader/v2/LICENSE b/vendor/github.com/jackc/chunkreader/v2/LICENSE new file mode 100644 index 000000000..c1c4f50fc --- /dev/null +++ b/vendor/github.com/jackc/chunkreader/v2/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2019 Jack Christensen + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/jackc/chunkreader/v2/README.md b/vendor/github.com/jackc/chunkreader/v2/README.md new file mode 100644 index 000000000..01209bfa2 --- /dev/null +++ b/vendor/github.com/jackc/chunkreader/v2/README.md @@ -0,0 +1,8 @@ +[](https://godoc.org/github.com/jackc/chunkreader) +[](https://travis-ci.org/jackc/chunkreader) + +# chunkreader + +Package chunkreader provides an io.Reader wrapper that minimizes IO reads and memory allocations. + +Extracted from original implementation in https://github.com/jackc/pgx. diff --git a/vendor/github.com/jackc/chunkreader/v2/chunkreader.go b/vendor/github.com/jackc/chunkreader/v2/chunkreader.go new file mode 100644 index 000000000..afea1c520 --- /dev/null +++ b/vendor/github.com/jackc/chunkreader/v2/chunkreader.go @@ -0,0 +1,104 @@ +// Package chunkreader provides an io.Reader wrapper that minimizes IO reads and memory allocations. +package chunkreader + +import ( + "io" +) + +// ChunkReader is a io.Reader wrapper that minimizes IO reads and memory allocations. It allocates memory in chunks and +// will read as much as will fit in the current buffer in a single call regardless of how large a read is actually +// requested. The memory returned via Next is owned by the caller. This avoids the need for an additional copy. +// +// The downside of this approach is that a large buffer can be pinned in memory even if only a small slice is +// referenced. For example, an entire 4096 byte block could be pinned in memory by even a 1 byte slice. In these rare +// cases it would be advantageous to copy the bytes to another slice. +type ChunkReader struct { + r io.Reader + + buf []byte + rp, wp int // buf read position and write position + + config Config +} + +// Config contains configuration parameters for ChunkReader. +type Config struct { + MinBufLen int // Minimum buffer length +} + +// New creates and returns a new ChunkReader for r with default configuration. +func New(r io.Reader) *ChunkReader { + cr, err := NewConfig(r, Config{}) + if err != nil { + panic("default config can't be bad") + } + + return cr +} + +// NewConfig creates and a new ChunkReader for r configured by config. +func NewConfig(r io.Reader, config Config) (*ChunkReader, error) { + if config.MinBufLen == 0 { + // By historical reasons Postgres currently has 8KB send buffer inside, + // so here we want to have at least the same size buffer. + // @see https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134 + // @see https://www.postgresql.org/message-id/0cdc5485-cb3c-5e16-4a46-e3b2f7a41322%40ya.ru + config.MinBufLen = 8192 + } + + return &ChunkReader{ + r: r, + buf: make([]byte, config.MinBufLen), + config: config, + }, nil +} + +// Next returns buf filled with the next n bytes. The caller gains ownership of buf. It is not necessary to make a copy +// of buf. If an error occurs, buf will be nil. +func (r *ChunkReader) Next(n int) (buf []byte, err error) { + // n bytes already in buf + if (r.wp - r.rp) >= n { + buf = r.buf[r.rp : r.rp+n] + r.rp += n + return buf, err + } + + // available space in buf is less than n + if len(r.buf) < n { + r.copyBufContents(r.newBuf(n)) + } + + // buf is large enough, but need to shift filled area to start to make enough contiguous space + minReadCount := n - (r.wp - r.rp) + if (len(r.buf) - r.wp) < minReadCount { + newBuf := r.newBuf(n) + r.copyBufContents(newBuf) + } + + if err := r.appendAtLeast(minReadCount); err != nil { + return nil, err + } + + buf = r.buf[r.rp : r.rp+n] + r.rp += n + return buf, nil +} + +func (r *ChunkReader) appendAtLeast(fillLen int) error { + n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen) + r.wp += n + return err +} + +func (r *ChunkReader) newBuf(size int) []byte { + if size < r.config.MinBufLen { + size = r.config.MinBufLen + } + return make([]byte, size) +} + +func (r *ChunkReader) copyBufContents(dest []byte) { + r.wp = copy(dest, r.buf[r.rp:r.wp]) + r.rp = 0 + r.buf = dest +} diff --git a/vendor/github.com/jackc/chunkreader/v2/go.mod b/vendor/github.com/jackc/chunkreader/v2/go.mod new file mode 100644 index 000000000..a1384b407 --- /dev/null +++ b/vendor/github.com/jackc/chunkreader/v2/go.mod @@ -0,0 +1,3 @@ +module github.com/jackc/chunkreader/v2 + +go 1.12 diff --git a/vendor/github.com/jackc/pgconn/.gitignore b/vendor/github.com/jackc/pgconn/.gitignore new file mode 100644 index 000000000..e980f5555 --- /dev/null +++ b/vendor/github.com/jackc/pgconn/.gitignore @@ -0,0 +1,3 @@ +.envrc +vendor/ +.vscode diff --git a/vendor/github.com/jackc/pgconn/CHANGELOG.md b/vendor/github.com/jackc/pgconn/CHANGELOG.md new file mode 100644 index 000000000..45c02f1e9 --- /dev/null +++ b/vendor/github.com/jackc/pgconn/CHANGELOG.md @@ -0,0 +1,122 @@ +# 1.10.0 (July 24, 2021) + +* net.Timeout errors are no longer returned when a query is canceled via context. A wrapped context error is returned. + +# 1.9.0 (July 10, 2021) + +* pgconn.Timeout only is true for errors originating in pgconn (Michael Darr) +* Add defaults for sslcert, sslkey, and sslrootcert (Joshua Brindle) +* Solve issue with 'sslmode=verify-full' when there are multiple hosts (mgoddard) +* Fix default host when parsing URL without host but with port +* Allow dbname query parameter in URL conn string +* Update underlying dependencies + +# 1.8.1 (March 25, 2021) + +* Better connection string sanitization (ip.novikov) +* Use proper pgpass location on Windows (Moshe Katz) +* Use errors instead of golang.org/x/xerrors +* Resume fallback on server error in Connect (Andrey Borodin) + +# 1.8.0 (December 3, 2020) + +* Add StatementErrored method to stmtcache.Cache. This allows the cache to purge invalidated prepared statements. (Ethan Pailes) + +# 1.7.2 (November 3, 2020) + +* Fix data value slices into work buffer with capacities larger than length. + +# 1.7.1 (October 31, 2020) + +* Do not asyncClose after receiving FATAL error from PostgreSQL server + +# 1.7.0 (September 26, 2020) + +* Exec(Params|Prepared) return ResultReader with FieldDescriptions loaded +* Add ReceiveResults (Sebastiaan Mannem) +* Fix parsing DSN connection with bad backslash +* Add PgConn.CleanupDone so connection pools can determine when async close is complete + +# 1.6.4 (July 29, 2020) + +* Fix deadlock on error after CommandComplete but before ReadyForQuery +* Fix panic on parsing DSN with trailing '=' + +# 1.6.3 (July 22, 2020) + +* Fix error message after AppendCertsFromPEM failure (vahid-sohrabloo) + +# 1.6.2 (July 14, 2020) + +* Update pgservicefile library + +# 1.6.1 (June 27, 2020) + +* Update golang.org/x/crypto to latest +* Update golang.org/x/text to 0.3.3 +* Fix error handling for bad PGSERVICE definition +* Redact passwords in ParseConfig errors (Lukas Vogel) + +# 1.6.0 (June 6, 2020) + +* Fix panic when closing conn during cancellable query +* Fix behavior of sslmode=require with sslrootcert present (Petr Jediný) +* Fix field descriptions available after command concluded (Tobias Salzmann) +* Support connect_timeout (georgysavva) +* Handle IPv6 in connection URLs (Lukas Vogel) +* Fix ValidateConnect with cancelable context +* Improve CopyFrom performance +* Add Config.Copy (georgysavva) + +# 1.5.0 (March 30, 2020) + +* Update golang.org/x/crypto for security fix +* Implement "verify-ca" SSL mode (Greg Curtis) + +# 1.4.0 (March 7, 2020) + +* Fix ExecParams and ExecPrepared handling of empty query. +* Support reading config from PostgreSQL service files. + +# 1.3.2 (February 14, 2020) + +* Update chunkreader to v2.0.1 for optimized default buffer size. + +# 1.3.1 (February 5, 2020) + +* Fix CopyFrom deadlock when multiple NoticeResponse received during copy + +# 1.3.0 (January 23, 2020) + +* Add Hijack and Construct. +* Update pgproto3 to v2.0.1. + +# 1.2.1 (January 13, 2020) + +* Fix data race in context cancellation introduced in v1.2.0. + +# 1.2.0 (January 11, 2020) + +## Features + +* Add Insert(), Update(), Delete(), and Select() statement type query methods to CommandTag. +* Add PgError.SQLState method. This could be used for compatibility with other drivers and databases. + +## Performance + +* Improve performance when context.Background() is used. (bakape) +* CommandTag.RowsAffected is faster and does not allocate. + +## Fixes + +* Try to cancel any in-progress query when a conn is closed by ctx cancel. +* Handle NoticeResponse during CopyFrom. +* Ignore errors sending Terminate message while closing connection. This mimics the behavior of libpq PGfinish. + +# 1.1.0 (October 12, 2019) + +* Add PgConn.IsBusy() method. + +# 1.0.1 (September 19, 2019) + +* Fix statement cache not properly cleaning discarded statements. diff --git a/vendor/github.com/jackc/pgconn/LICENSE b/vendor/github.com/jackc/pgconn/LICENSE new file mode 100644 index 000000000..aebadd6c4 --- /dev/null +++ b/vendor/github.com/jackc/pgconn/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2019-2021 Jack Christensen + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/jackc/pgconn/README.md b/vendor/github.com/jackc/pgconn/README.md new file mode 100644 index 000000000..1c698a118 --- /dev/null +++ b/vendor/github.com/jackc/pgconn/README.md @@ -0,0 +1,56 @@ +[](https://godoc.org/github.com/jackc/pgconn) + + +# pgconn + +Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level as the C library libpq. +It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx. +Applications should handle normal queries with a higher level library and only use pgconn directly when required for +low-level access to PostgreSQL functionality. + +## Example Usage + +```go +pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL")) +if err != nil { + log.Fatalln("pgconn failed to connect:", err) +} +defer pgConn.Close(context.Background()) + +result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) +for result.NextRow() { + fmt.Println("User 123 has email:", string(result.Values()[0])) +} +_, err = result.Close() +if err != nil { + log.Fatalln("failed reading result:", err) +} +``` + +## Testing + +The pgconn tests require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_CONN_STRING` +environment variable. The `PGX_TEST_CONN_STRING` environment variable can be a URL or DSN. In addition, the standard `PG*` +environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify +environment variable handling. + +### Example Test Environment + +Connect to your PostgreSQL server and run: + +``` +create database pgx_test; +``` + +Now you can run the tests: + +```bash +PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" go test ./... +``` + +### Connection and Authentication Tests + +Pgconn supports multiple connection types and means of authentication. These tests are optional. They +will only run if the appropriate environment variable is set. Run `go test -v | grep SKIP` to see if any tests are being +skipped. Most developers will not need to enable these tests. See `ci/setup_test.bash` for an example set up if you need change +authentication code. diff --git a/vendor/github.com/jackc/pgconn/auth_scram.go b/vendor/github.com/jackc/pgconn/auth_scram.go new file mode 100644 index 000000000..6a143fcdc --- /dev/null +++ b/vendor/github.com/jackc/pgconn/auth_scram.go @@ -0,0 +1,266 @@ +// SCRAM-SHA-256 authentication +// +// Resources: +// https://tools.ietf.org/html/rfc5802 +// https://tools.ietf.org/html/rfc8265 +// https://www.postgresql.org/docs/current/sasl-authentication.html +// +// Inspiration drawn from other implementations: +// https://github.com/lib/pq/pull/608 +// https://github.com/lib/pq/pull/788 +// https://github.com/lib/pq/pull/833 + +package pgconn + +import ( + "bytes" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "strconv" + + "github.com/jackc/pgproto3/v2" + "golang.org/x/crypto/pbkdf2" + "golang.org/x/text/secure/precis" +) + +const clientNonceLen = 18 + +// Perform SCRAM authentication. +func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { + sc, err := newScramClient(serverAuthMechanisms, c.config.Password) + if err != nil { + return err + } + + // Send client-first-message in a SASLInitialResponse + saslInitialResponse := &pgproto3.SASLInitialResponse{ + AuthMechanism: "SCRAM-SHA-256", + Data: sc.clientFirstMessage(), + } + _, err = c.conn.Write(saslInitialResponse.Encode(nil)) + if err != nil { + return err + } + + // Receive server-first-message payload in a AuthenticationSASLContinue. + saslContinue, err := c.rxSASLContinue() + if err != nil { + return err + } + err = sc.recvServerFirstMessage(saslContinue.Data) + if err != nil { + return err + } + + // Send client-final-message in a SASLResponse + saslResponse := &pgproto3.SASLResponse{ + Data: []byte(sc.clientFinalMessage()), + } + _, err = c.conn.Write(saslResponse.Encode(nil)) + if err != nil { + return err + } + + // Receive server-final-message payload in a AuthenticationSASLFinal. + saslFinal, err := c.rxSASLFinal() + if err != nil { + return err + } + return sc.recvServerFinalMessage(saslFinal.Data) +} + +func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) { + msg, err := c.receiveMessage() + if err != nil { + return nil, err + } + saslContinue, ok := msg.(*pgproto3.AuthenticationSASLContinue) + if ok { + return saslContinue, nil + } + + return nil, errors.New("expected AuthenticationSASLContinue message but received unexpected message") +} + +func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) { + msg, err := c.receiveMessage() + if err != nil { + return nil, err + } + saslFinal, ok := msg.(*pgproto3.AuthenticationSASLFinal) + if ok { + return saslFinal, nil + } + + return nil, errors.New("expected AuthenticationSASLFinal message but received unexpected message") +} + +type scramClient struct { + serverAuthMechanisms []string + password []byte + clientNonce []byte + + clientFirstMessageBare []byte + + serverFirstMessage []byte + clientAndServerNonce []byte + salt []byte + iterations int + + saltedPassword []byte + authMessage []byte +} + +func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) { + sc := &scramClient{ + serverAuthMechanisms: serverAuthMechanisms, + } + + // Ensure server supports SCRAM-SHA-256 + hasScramSHA256 := false + for _, mech := range sc.serverAuthMechanisms { + if mech == "SCRAM-SHA-256" { + hasScramSHA256 = true + break + } + } + if !hasScramSHA256 { + return nil, errors.New("server does not support SCRAM-SHA-256") + } + + // precis.OpaqueString is equivalent to SASLprep for password. + var err error + sc.password, err = precis.OpaqueString.Bytes([]byte(password)) + if err != nil { + // PostgreSQL allows passwords invalid according to SCRAM / SASLprep. + sc.password = []byte(password) + } + + buf := make([]byte, clientNonceLen) + _, err = rand.Read(buf) + if err != nil { + return nil, err + } + sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf))) + base64.RawStdEncoding.Encode(sc.clientNonce, buf) + + return sc, nil +} + +func (sc *scramClient) clientFirstMessage() []byte { + sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce)) + return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare)) +} + +func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { + sc.serverFirstMessage = serverFirstMessage + buf := serverFirstMessage + if !bytes.HasPrefix(buf, []byte("r=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include r=") + } + buf = buf[2:] + + idx := bytes.IndexByte(buf, ',') + if idx == -1 { + return errors.New("invalid SCRAM server-first-message received from server: did not include s=") + } + sc.clientAndServerNonce = buf[:idx] + buf = buf[idx+1:] + + if !bytes.HasPrefix(buf, []byte("s=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include s=") + } + buf = buf[2:] + + idx = bytes.IndexByte(buf, ',') + if idx == -1 { + return errors.New("invalid SCRAM server-first-message received from server: did not include i=") + } + saltStr := buf[:idx] + buf = buf[idx+1:] + + if !bytes.HasPrefix(buf, []byte("i=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include i=") + } + buf = buf[2:] + iterationsStr := buf + + var err error + sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr)) + if err != nil { + return fmt.Errorf("invalid SCRAM salt received from server: %w", err) + } + + sc.iterations, err = strconv.Atoi(string(iterationsStr)) + if err != nil || sc.iterations <= 0 { + return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err) + } + + if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) { + return errors.New("invalid SCRAM nonce: did not start with client nonce") + } + + if len(sc.clientAndServerNonce) <= len(sc.clientNonce) { + return errors.New("invalid SCRAM nonce: did not include server nonce") + } + + return nil +} + +func (sc *scramClient) clientFinalMessage() string { + clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce)) + + sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New) + sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(",")) + + clientProof := computeClientProof(sc.saltedPassword, sc.authMessage) + + return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof) +} + +func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error { + if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) { + return errors.New("invalid SCRAM server-final-message received from server") + } + + serverSignature := serverFinalMessage[2:] + + if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) { + return errors.New("invalid SCRAM ServerSignature received from server") + } + + return nil +} + +func computeHMAC(key, msg []byte) []byte { + mac := hmac.New(sha256.New, key) + mac.Write(msg) + return mac.Sum(nil) +} + +func computeClientProof(saltedPassword, authMessage []byte) []byte { + clientKey := computeHMAC(saltedPassword, []byte("Client Key")) + storedKey := sha256.Sum256(clientKey) + clientSignature := computeHMAC(storedKey[:], authMessage) + + clientProof := make([]byte, len(clientSignature)) + for i := 0; i < len(clientSignature); i++ { + clientProof[i] = clientKey[i] ^ clientSignature[i] + } + + buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof))) + base64.StdEncoding.Encode(buf, clientProof) + return buf +} + +func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte { + serverKey := computeHMAC(saltedPassword, []byte("Server Key")) + serverSignature := computeHMAC(serverKey, authMessage) + buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature))) + base64.StdEncoding.Encode(buf, serverSignature) + return buf +} diff --git a/vendor/github.com/jackc/pgconn/config.go b/vendor/github.com/jackc/pgconn/config.go new file mode 100644 index 000000000..172e7478b --- /dev/null +++ b/vendor/github.com/jackc/pgconn/config.go @@ -0,0 +1,729 @@ +package pgconn + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io" + "io/ioutil" + "math" + "net" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/jackc/chunkreader/v2" + "github.com/jackc/pgpassfile" + "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgservicefile" +) + +type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error +type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error + +// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A +// manually initialized Config will cause ConnectConfig to panic. +type Config struct { + Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + Database string + User string + Password string + TLSConfig *tls.Config // nil disables TLS + ConnectTimeout time.Duration + DialFunc DialFunc // e.g. net.Dialer.DialContext + LookupFunc LookupFunc // e.g. net.Resolver.LookupHost + BuildFrontend BuildFrontendFunc + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + + Fallbacks []*FallbackConfig + + // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. + // It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next + // fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs. + ValidateConnect ValidateConnectFunc + + // AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables + // or prepare statements). If this returns an error the connection attempt fails. + AfterConnect AfterConnectFunc + + // OnNotice is a callback function called when a notice response is received. + OnNotice NoticeHandler + + // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. + OnNotification NotificationHandler + + createdByParseConfig bool // Used to enforce created by ParseConfig rule. +} + +// Copy returns a deep copy of the config that is safe to use and modify. +// The only exception is the TLSConfig field: +// according to the tls.Config docs it must not be modified after creation. +func (c *Config) Copy() *Config { + newConf := new(Config) + *newConf = *c + if newConf.TLSConfig != nil { + newConf.TLSConfig = c.TLSConfig.Clone() + } + if newConf.RuntimeParams != nil { + newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams)) + for k, v := range c.RuntimeParams { + newConf.RuntimeParams[k] = v + } + } + if newConf.Fallbacks != nil { + newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks)) + for i, fallback := range c.Fallbacks { + newFallback := new(FallbackConfig) + *newFallback = *fallback + if newFallback.TLSConfig != nil { + newFallback.TLSConfig = fallback.TLSConfig.Clone() + } + newConf.Fallbacks[i] = newFallback + } + } + return newConf +} + +// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a +// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections. +type FallbackConfig struct { + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + TLSConfig *tls.Config // nil disables TLS +} + +// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with +// net.Dial. +func NetworkAddress(host string, port uint16) (network, address string) { + if strings.HasPrefix(host, "/") { + network = "unix" + address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10) + } else { + network = "tcp" + address = net.JoinHostPort(host, strconv.Itoa(int(port))) + } + return network, address +} + +// ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same +// defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely matches +// the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). See +// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be +// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file. +// +// # Example DSN +// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca +// +// # Example URL +// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca +// +// The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done +// through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be +// interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should +// not be modified individually. They should all be modified or all left unchanged. +// +// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated +// values that will be tried in order. This can be used as part of a high availability system. See +// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. +// +// # Example URL +// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb +// +// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed +// via database URL or DSN: +// +// PGHOST +// PGPORT +// PGDATABASE +// PGUSER +// PGPASSWORD +// PGPASSFILE +// PGSERVICE +// PGSERVICEFILE +// PGSSLMODE +// PGSSLCERT +// PGSSLKEY +// PGSSLROOTCERT +// PGAPPNAME +// PGCONNECT_TIMEOUT +// PGTARGETSESSIONATTRS +// +// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. +// +// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are +// usually but not always the environment variable name downcased and without the "PG" prefix. +// +// Important Security Notes: +// +// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if +// not set. +// +// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of +// security each sslmode provides. +// +// The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of +// the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of +// sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback +// which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually +// changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting +// TLCConfig. +// +// Other known differences with libpq: +// +// If a host name resolves into multiple addresses, libpq will try all addresses. pgconn will only try the first. +// +// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn +// does not. +// +// In addition, ParseConfig accepts the following options: +// +// min_read_buffer_size +// The minimum size of the internal read buffer. Default 8192. +// servicefile +// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a +// part of the connection string. +func ParseConfig(connString string) (*Config, error) { + defaultSettings := defaultSettings() + envSettings := parseEnvSettings() + + connStringSettings := make(map[string]string) + if connString != "" { + var err error + // connString may be a database URL or a DSN + if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { + connStringSettings, err = parseURLSettings(connString) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} + } + } else { + connStringSettings, err = parseDSNSettings(connString) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err} + } + } + } + + settings := mergeSettings(defaultSettings, envSettings, connStringSettings) + if service, present := settings["service"]; present { + serviceSettings, err := parseServiceSettings(settings["servicefile"], service) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err} + } + + settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings) + } + + minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err} + } + + config := &Config{ + createdByParseConfig: true, + Database: settings["database"], + User: settings["user"], + Password: settings["password"], + RuntimeParams: make(map[string]string), + BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)), + } + + if connectTimeoutSetting, present := settings["connect_timeout"]; present { + connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err} + } + config.ConnectTimeout = connectTimeout + config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout) + } else { + defaultDialer := makeDefaultDialer() + config.DialFunc = defaultDialer.DialContext + } + + config.LookupFunc = makeDefaultResolver().LookupHost + + notRuntimeParams := map[string]struct{}{ + "host": struct{}{}, + "port": struct{}{}, + "database": struct{}{}, + "user": struct{}{}, + "password": struct{}{}, + "passfile": struct{}{}, + "connect_timeout": struct{}{}, + "sslmode": struct{}{}, + "sslkey": struct{}{}, + "sslcert": struct{}{}, + "sslrootcert": struct{}{}, + "target_session_attrs": struct{}{}, + "min_read_buffer_size": struct{}{}, + "service": struct{}{}, + "servicefile": struct{}{}, + } + + for k, v := range settings { + if _, present := notRuntimeParams[k]; present { + continue + } + config.RuntimeParams[k] = v + } + + fallbacks := []*FallbackConfig{} + + hosts := strings.Split(settings["host"], ",") + ports := strings.Split(settings["port"], ",") + + for i, host := range hosts { + var portStr string + if i < len(ports) { + portStr = ports[i] + } else { + portStr = ports[0] + } + + port, err := parsePort(portStr) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err} + } + + var tlsConfigs []*tls.Config + + // Ignore TLS settings if Unix domain socket like libpq + if network, _ := NetworkAddress(host, port); network == "unix" { + tlsConfigs = append(tlsConfigs, nil) + } else { + var err error + tlsConfigs, err = configTLS(settings, host) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} + } + } + + for _, tlsConfig := range tlsConfigs { + fallbacks = append(fallbacks, &FallbackConfig{ + Host: host, + Port: port, + TLSConfig: tlsConfig, + }) + } + } + + config.Host = fallbacks[0].Host + config.Port = fallbacks[0].Port + config.TLSConfig = fallbacks[0].TLSConfig + config.Fallbacks = fallbacks[1:] + + passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) + if err == nil { + if config.Password == "" { + host := config.Host + if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" { + host = "localhost" + } + + config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User) + } + } + + if settings["target_session_attrs"] == "read-write" { + config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite + } else if settings["target_session_attrs"] != "any" { + return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", settings["target_session_attrs"])} + } + + return config, nil +} + +func mergeSettings(settingSets ...map[string]string) map[string]string { + settings := make(map[string]string) + + for _, s2 := range settingSets { + for k, v := range s2 { + settings[k] = v + } + } + + return settings +} + +func parseEnvSettings() map[string]string { + settings := make(map[string]string) + + nameMap := map[string]string{ + "PGHOST": "host", + "PGPORT": "port", + "PGDATABASE": "database", + "PGUSER": "user", + "PGPASSWORD": "password", + "PGPASSFILE": "passfile", + "PGAPPNAME": "application_name", + "PGCONNECT_TIMEOUT": "connect_timeout", + "PGSSLMODE": "sslmode", + "PGSSLKEY": "sslkey", + "PGSSLCERT": "sslcert", + "PGSSLROOTCERT": "sslrootcert", + "PGTARGETSESSIONATTRS": "target_session_attrs", + "PGSERVICE": "service", + "PGSERVICEFILE": "servicefile", + } + + for envname, realname := range nameMap { + value := os.Getenv(envname) + if value != "" { + settings[realname] = value + } + } + + return settings +} + +func parseURLSettings(connString string) (map[string]string, error) { + settings := make(map[string]string) + + url, err := url.Parse(connString) + if err != nil { + return nil, err + } + + if url.User != nil { + settings["user"] = url.User.Username() + if password, present := url.User.Password(); present { + settings["password"] = password + } + } + + // Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port. + var hosts []string + var ports []string + for _, host := range strings.Split(url.Host, ",") { + if host == "" { + continue + } + if isIPOnly(host) { + hosts = append(hosts, strings.Trim(host, "[]")) + continue + } + h, p, err := net.SplitHostPort(host) + if err != nil { + return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err) + } + if h != "" { + hosts = append(hosts, h) + } + if p != "" { + ports = append(ports, p) + } + } + if len(hosts) > 0 { + settings["host"] = strings.Join(hosts, ",") + } + if len(ports) > 0 { + settings["port"] = strings.Join(ports, ",") + } + + database := strings.TrimLeft(url.Path, "/") + if database != "" { + settings["database"] = database + } + + nameMap := map[string]string{ + "dbname": "database", + } + + for k, v := range url.Query() { + if k2, present := nameMap[k]; present { + k = k2 + } + + settings[k] = v[0] + } + + return settings, nil +} + +func isIPOnly(host string) bool { + return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":") +} + +var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} + +func parseDSNSettings(s string) (map[string]string, error) { + settings := make(map[string]string) + + nameMap := map[string]string{ + "dbname": "database", + } + + for len(s) > 0 { + var key, val string + eqIdx := strings.IndexRune(s, '=') + if eqIdx < 0 { + return nil, errors.New("invalid dsn") + } + + key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") + s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f") + if len(s) == 0 { + } else if s[0] != '\'' { + end := 0 + for ; end < len(s); end++ { + if asciiSpace[s[end]] == 1 { + break + } + if s[end] == '\\' { + end++ + if end == len(s) { + return nil, errors.New("invalid backslash") + } + } + } + val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) + if end == len(s) { + s = "" + } else { + s = s[end+1:] + } + } else { // quoted string + s = s[1:] + end := 0 + for ; end < len(s); end++ { + if s[end] == '\'' { + break + } + if s[end] == '\\' { + end++ + } + } + if end == len(s) { + return nil, errors.New("unterminated quoted string in connection info string") + } + val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) + if end == len(s) { + s = "" + } else { + s = s[end+1:] + } + } + + if k, ok := nameMap[key]; ok { + key = k + } + + if key == "" { + return nil, errors.New("invalid dsn") + } + + settings[key] = val + } + + return settings, nil +} + +func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) { + servicefile, err := pgservicefile.ReadServicefile(servicefilePath) + if err != nil { + return nil, fmt.Errorf("failed to read service file: %v", servicefilePath) + } + + service, err := servicefile.GetService(serviceName) + if err != nil { + return nil, fmt.Errorf("unable to find service: %v", serviceName) + } + + nameMap := map[string]string{ + "dbname": "database", + } + + settings := make(map[string]string, len(service.Settings)) + for k, v := range service.Settings { + if k2, present := nameMap[k]; present { + k = k2 + } + settings[k] = v + } + + return settings, nil +} + +// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is +// necessary to allow returning multiple TLS configs as sslmode "allow" and +// "prefer" allow fallback. +func configTLS(settings map[string]string, thisHost string) ([]*tls.Config, error) { + host := thisHost + sslmode := settings["sslmode"] + sslrootcert := settings["sslrootcert"] + sslcert := settings["sslcert"] + sslkey := settings["sslkey"] + + // Match libpq default behavior + if sslmode == "" { + sslmode = "prefer" + } + + tlsConfig := &tls.Config{} + + switch sslmode { + case "disable": + return []*tls.Config{nil}, nil + case "allow", "prefer": + tlsConfig.InsecureSkipVerify = true + case "require": + // According to PostgreSQL documentation, if a root CA file exists, + // the behavior of sslmode=require should be the same as that of verify-ca + // + // See https://www.postgresql.org/docs/12/libpq-ssl.html + if sslrootcert != "" { + goto nextCase + } + tlsConfig.InsecureSkipVerify = true + break + nextCase: + fallthrough + case "verify-ca": + // Don't perform the default certificate verification because it + // will verify the hostname. Instead, verify the server's + // certificate chain ourselves in VerifyPeerCertificate and + // ignore the server name. This emulates libpq's verify-ca + // behavior. + // + // See https://github.com/golang/go/issues/21971#issuecomment-332693931 + // and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate + // for more info. + tlsConfig.InsecureSkipVerify = true + tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error { + certs := make([]*x509.Certificate, len(certificates)) + for i, asn1Data := range certificates { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + return errors.New("failed to parse certificate from server: " + err.Error()) + } + certs[i] = cert + } + + // Leave DNSName empty to skip hostname verification. + opts := x509.VerifyOptions{ + Roots: tlsConfig.RootCAs, + Intermediates: x509.NewCertPool(), + } + // Skip the first cert because it's the leaf. All others + // are intermediates. + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + _, err := certs[0].Verify(opts) + return err + } + case "verify-full": + tlsConfig.ServerName = host + default: + return nil, errors.New("sslmode is invalid") + } + + if sslrootcert != "" { + caCertPool := x509.NewCertPool() + + caPath := sslrootcert + caCert, err := ioutil.ReadFile(caPath) + if err != nil { + return nil, fmt.Errorf("unable to read CA file: %w", err) + } + + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, errors.New("unable to add CA to cert pool") + } + + tlsConfig.RootCAs = caCertPool + tlsConfig.ClientCAs = caCertPool + } + + if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { + return nil, errors.New(`both "sslcert" and "sslkey" are required`) + } + + if sslcert != "" && sslkey != "" { + cert, err := tls.LoadX509KeyPair(sslcert, sslkey) + if err != nil { + return nil, fmt.Errorf("unable to read cert: %w", err) + } + + tlsConfig.Certificates = []tls.Certificate{cert} + } + + switch sslmode { + case "allow": + return []*tls.Config{nil, tlsConfig}, nil + case "prefer": + return []*tls.Config{tlsConfig, nil}, nil + case "require", "verify-ca", "verify-full": + return []*tls.Config{tlsConfig}, nil + default: + panic("BUG: bad sslmode should already have been caught") + } +} + +func parsePort(s string) (uint16, error) { + port, err := strconv.ParseUint(s, 10, 16) + if err != nil { + return 0, err + } + if port < 1 || port > math.MaxUint16 { + return 0, errors.New("outside range") + } + return uint16(port), nil +} + +func makeDefaultDialer() *net.Dialer { + return &net.Dialer{KeepAlive: 5 * time.Minute} +} + +func makeDefaultResolver() *net.Resolver { + return net.DefaultResolver +} + +func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc { + return func(r io.Reader, w io.Writer) Frontend { + cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen}) + if err != nil { + panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err)) + } + frontend := pgproto3.NewFrontend(cr, w) + + return frontend + } +} + +func parseConnectTimeoutSetting(s string) (time.Duration, error) { + timeout, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return 0, err + } + if timeout < 0 { + return 0, errors.New("negative timeout") + } + return time.Duration(timeout) * time.Second, nil +} + +func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc { + d := makeDefaultDialer() + d.Timeout = timeout + return d.DialContext +} + +// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible +// target_session_attrs=read-write. +func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { + result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() + if result.Err != nil { + return result.Err + } + + if string(result.Rows[0][0]) == "on" { + return errors.New("read only connection") + } + + return nil +} diff --git a/vendor/github.com/jackc/pgconn/defaults.go b/vendor/github.com/jackc/pgconn/defaults.go new file mode 100644 index 000000000..f69cad317 --- /dev/null +++ b/vendor/github.com/jackc/pgconn/defaults.go @@ -0,0 +1,64 @@ +// +build !windows + +package pgconn + +import ( + "os" + "os/user" + "path/filepath" +) + +func defaultSettings() map[string]string { + settings := make(map[string]string) + + settings["host"] = defaultHost() + settings["port"] = "5432" + + // Default to the OS user name. Purposely ignoring err getting user name from + // OS. The client application will simply have to specify the user in that + // case (which they typically will be doing anyway). + user, err := user.Current() + if err == nil { + settings["user"] = user.Username + settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") + settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") + sslcert := filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") + sslkey := filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") + if _, err := os.Stat(sslcert); err == nil { + if _, err := os.Stat(sslkey); err == nil { + // Both the cert and key must be present to use them, or do not use either + settings["sslcert"] = sslcert + settings["sslkey"] = sslkey + } + } + sslrootcert := filepath.Join(user.HomeDir, ".postgresql", "root.crt") + if _, err := os.Stat(sslrootcert); err == nil { + settings["sslrootcert"] = sslrootcert + } + } + + settings["target_session_attrs"] = "any" + + settings["min_read_buffer_size"] = "8192" + + return settings +} + +// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost +// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it +// checks the existence of common locations. +func defaultHost() string { + candidatePaths := []string{ + "/var/run/postgresql", // Debian + "/private/tmp", // OSX - homebrew + "/tmp", // standard PostgreSQL + } + + for _, path := range candidatePaths { + if _, err := os.Stat(path); err == nil { + return path + } + } + + return "localhost" +} diff --git a/vendor/github.com/jackc/pgconn/defaults_windows.go b/vendor/github.com/jackc/pgconn/defaults_windows.go new file mode 100644 index 000000000..71eb77dba --- /dev/null +++ b/vendor/github.com/jackc/pgconn/defaults_windows.go @@ -0,0 +1,59 @@ +package pgconn + +import ( + "os" + "os/user" + "path/filepath" + "strings" +) + +func defaultSettings() map[string]string { + settings := make(map[string]string) + + settings["host"] = defaultHost() + settings["port"] = "5432" + + // Default to the OS user name. Purposely ignoring err getting user name from + // OS. The client application will simply have to specify the user in that + // case (which they typically will be doing anyway). + user, err := user.Current() + appData := os.Getenv("APPDATA") + if err == nil { + // Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`, + // but the libpq default is just the `user` portion, so we strip off the first part. + username := user.Username + if strings.Contains(username, "\\") { + username = username[strings.LastIndex(username, "\\")+1:] + } + + settings["user"] = username + settings["passfile"] = filepath.Join(appData, "postgresql", "pgpass.conf") + settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") + sslcert := filepath.Join(appData, "postgresql", "postgresql.crt") + sslkey := filepath.Join(appData, "postgresql", "postgresql.key") + if _, err := os.Stat(sslcert); err == nil { + if _, err := os.Stat(sslkey); err == nil { + // Both the cert and key must be present to use them, or do not use either + settings["sslcert"] = sslcert + settings["sslkey"] = sslkey + } + } + sslrootcert := filepath.Join(appData, "postgresql", "root.crt") + if _, err := os.Stat(sslrootcert); err == nil { + settings["sslrootcert"] = sslrootcert + } + } + + settings["target_session_attrs"] = "any" + + settings["min_read_buffer_size"] = "8192" + + return settings +} + +// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost +// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it +// checks the existence of common locations. +func defaultHost() string { + return "localhost" +} diff --git a/vendor/github.com/jackc/pgconn/doc.go b/vendor/github.com/jackc/pgconn/doc.go new file mode 100644 index 000000000..cde58cd89 --- /dev/null +++ b/vendor/github.com/jackc/pgconn/doc.go @@ -0,0 +1,29 @@ +// Package pgconn is a low-level PostgreSQL database driver. +/* +pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at +nearly the same level is the C library libpq. + +Establishing a Connection + +Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for +libpq style environment variables. + +Executing a Query + +ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method +reads all rows into memory. + +Executing Multiple Queries in a Single Round Trip + +Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query +result. The ReadAll method reads all query results into memory. + +Context Support + +All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the +method immediately returns. In most circumstances, this will close the underlying connection. + +The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the +client to abort. +*/ +package pgconn diff --git a/vendor/github.com/jackc/pgconn/errors.go b/vendor/github.com/jackc/pgconn/errors.go new file mode 100644 index 000000000..a32b29c92 --- /dev/null +++ b/vendor/github.com/jackc/pgconn/errors.go @@ -0,0 +1,221 @@ +package pgconn + +import ( + "context" + "errors" + "fmt" + "net" + "net/url" + "regexp" + "strings" +) + +// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server. +func SafeToRetry(err error) bool { + if e, ok := err.(interface{ SafeToRetry() bool }); ok { + return e.SafeToRetry() + } + return false +} + +// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a +// context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true. +func Timeout(err error) bool { + var timeoutErr *errTimeout + return errors.As(err, &timeoutErr) +} + +// PgError represents an error reported by the PostgreSQL server. See +// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for +// detailed field description. +type PgError struct { + Severity string + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string +} + +func (pe *PgError) Error() string { + return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" +} + +// SQLState returns the SQLState of the error. +func (pe *PgError) SQLState() string { + return pe.Code +} + +type connectError struct { + config *Config + msg string + err error +} + +func (e *connectError) Error() string { + sb := &strings.Builder{} + fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg) + if e.err != nil { + fmt.Fprintf(sb, " (%s)", e.err.Error()) + } + return sb.String() +} + +func (e *connectError) Unwrap() error { + return e.err +} + +type connLockError struct { + status string +} + +func (e *connLockError) SafeToRetry() bool { + return true // a lock failure by definition happens before the connection is used. +} + +func (e *connLockError) Error() string { + return e.status +} + +type parseConfigError struct { + connString string + msg string + err error +} + +func (e *parseConfigError) Error() string { + connString := redactPW(e.connString) + if e.err == nil { + return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg) + } + return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error()) +} + +func (e *parseConfigError) Unwrap() error { + return e.err +} + +// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == +// true. Otherwise returns err. +func preferContextOverNetTimeoutError(ctx context.Context, err error) error { + if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { + return &errTimeout{err: ctx.Err()} + } + return err +} + +type pgconnError struct { + msg string + err error + safeToRetry bool +} + +func (e *pgconnError) Error() string { + if e.msg == "" { + return e.err.Error() + } + if e.err == nil { + return e.msg + } + return fmt.Sprintf("%s: %s", e.msg, e.err.Error()) +} + +func (e *pgconnError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *pgconnError) Unwrap() error { + return e.err +} + +// errTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is +// context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true. +type errTimeout struct { + err error +} + +func (e *errTimeout) Error() string { + return fmt.Sprintf("timeout: %s", e.err.Error()) +} + +func (e *errTimeout) SafeToRetry() bool { + return SafeToRetry(e.err) +} + +func (e *errTimeout) Unwrap() error { + return e.err +} + +type contextAlreadyDoneError struct { + err error +} + +func (e *contextAlreadyDoneError) Error() string { + return fmt.Sprintf("context already done: %s", e.err.Error()) +} + +func (e *contextAlreadyDoneError) SafeToRetry() bool { + return true +} + +func (e *contextAlreadyDoneError) Unwrap() error { + return e.err +} + +// newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `errTimeout`. +func newContextAlreadyDoneError(ctx context.Context) (err error) { + return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}} +} + +type writeError struct { + err error + safeToRetry bool +} + +func (e *writeError) Error() string { + return fmt.Sprintf("write failed: %s", e.err.Error()) +} + +func (e *writeError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *writeError) Unwrap() error { + return e.err +} + +func redactPW(connString string) string { + if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { + if u, err := url.Parse(connString); err == nil { + return redactURL(u) + } + } + quotedDSN := regexp.MustCompile(`password='[^']*'`) + connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx") + plainDSN := regexp.MustCompile(`password=[^ ]*`) + connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx") + brokenURL := regexp.MustCompile(`:[^:@]+?@`) + connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@") + return connString +} + +func redactURL(u *url.URL) string { + if u == nil { + return "" + } + if _, pwSet := u.User.Password(); pwSet { + u.User = url.UserPassword(u.User.Username(), "xxxxx") + } + return u.String() +} diff --git a/vendor/github.com/jackc/pgconn/go.mod b/vendor/github.com/jackc/pgconn/go.mod new file mode 100644 index 000000000..6fdd0e979 --- /dev/null +++ b/vendor/github.com/jackc/pgconn/go.mod @@ -0,0 +1,15 @@ +module github.com/jackc/pgconn + +go 1.12 + +require ( + github.com/jackc/chunkreader/v2 v2.0.1 + github.com/jackc/pgio v1.0.0 + github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 + github.com/jackc/pgpassfile v1.0.0 + github.com/jackc/pgproto3/v2 v2.1.1 + github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b + github.com/stretchr/testify v1.7.0 + golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 + golang.org/x/text v0.3.6 +) diff --git a/vendor/github.com/jackc/pgconn/go.sum b/vendor/github.com/jackc/pgconn/go.sum new file mode 100644 index 000000000..3c77ee21b --- /dev/null +++ b/vendor/github.com/jackc/pgconn/go.sum @@ -0,0 +1,130 @@ +github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= +github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= +github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= +github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= +github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= +github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= +github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI= +github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= +github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= +github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= +github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= +github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= +github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= +github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= +github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= +github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/vendor/github.com/jackc/pgconn/internal/ctxwatch/context_watcher.go b/vendor/github.com/jackc/pgconn/internal/ctxwatch/context_watcher.go new file mode 100644 index 000000000..391f0b791 --- /dev/null +++ b/vendor/github.com/jackc/pgconn/internal/ctxwatch/context_watcher.go @@ -0,0 +1,64 @@ +package ctxwatch + +import ( + "context" +) + +// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a +// time. +type ContextWatcher struct { + onCancel func() + onUnwatchAfterCancel func() + unwatchChan chan struct{} + watchInProgress bool + onCancelWasCalled bool +} + +// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled. +// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and +// onCancel called. +func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher { + cw := &ContextWatcher{ + onCancel: onCancel, + onUnwatchAfterCancel: onUnwatchAfterCancel, + unwatchChan: make(chan struct{}), + } + + return cw +} + +// Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called. +func (cw *ContextWatcher) Watch(ctx context.Context) { + if cw.watchInProgress { + panic("Watch already in progress") + } + + cw.onCancelWasCalled = false + + if ctx.Done() != nil { + cw.watchInProgress = true + go func() { + select { + case <-ctx.Done(): + cw.onCancel() + cw.onCancelWasCalled = true + <-cw.unwatchChan + case <-cw.unwatchChan: + } + }() + } else { + cw.watchInProgress = false + } +} + +// Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was +// called then onUnwatchAfterCancel will also be called. +func (cw *ContextWatcher) Unwatch() { + if cw.watchInProgress { + cw.unwatchChan <- struct{}{} + if cw.onCancelWasCalled { + cw.onUnwatchAfterCancel() + } + cw.watchInProgress = false + } +} diff --git a/vendor/github.com/jackc/pgconn/pgconn.go b/vendor/github.com/jackc/pgconn/pgconn.go new file mode 100644 index 000000000..43b13e43a --- /dev/null +++ b/vendor/github.com/jackc/pgconn/pgconn.go @@ -0,0 +1,1724 @@ +package pgconn + +import ( + "context" + "crypto/md5" + "crypto/tls" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + "math" + "net" + "strings" + "sync" + "time" + + "github.com/jackc/pgconn/internal/ctxwatch" + "github.com/jackc/pgio" + "github.com/jackc/pgproto3/v2" +) + +const ( + connStatusUninitialized = iota + connStatusConnecting + connStatusClosed + connStatusIdle + connStatusBusy +) + +const wbufLen = 1024 + +// Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from +// LISTEN/NOTIFY notification. +type Notice PgError + +// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system +type Notification struct { + PID uint32 // backend pid that sent the notification + Channel string // channel from which notification was received + Payload string +} + +// DialFunc is a function that can be used to connect to a PostgreSQL server. +type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) + +// LookupFunc is a function that can be used to lookup IPs addrs from host. +type LookupFunc func(ctx context.Context, host string) (addrs []string, err error) + +// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. +type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend + +// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at +// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin +// of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY +// notification. +type NoticeHandler func(*PgConn, *Notice) + +// NotificationHandler is a function that can handle notifications received from the PostgreSQL server. Notifications +// can be received at any time, usually during handling of a query response. The *PgConn is provided so the handler is +// aware of the origin of the notice, but it must not invoke any query method. Be aware that this is distinct from a +// notice event. +type NotificationHandler func(*PgConn, *Notification) + +// Frontend used to receive messages from backend. +type Frontend interface { + Receive() (pgproto3.BackendMessage, error) +} + +// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. +type PgConn struct { + conn net.Conn // the underlying TCP or unix domain socket connection + pid uint32 // backend pid + secretKey uint32 // key to use to send a cancel query message to the server + parameterStatuses map[string]string // parameters that have been reported by the server + txStatus byte + frontend Frontend + + config *Config + + status byte // One of connStatus* constants + + bufferingReceive bool + bufferingReceiveMux sync.Mutex + bufferingReceiveMsg pgproto3.BackendMessage + bufferingReceiveErr error + + peekedMsg pgproto3.BackendMessage + + // Reusable / preallocated resources + wbuf []byte // write buffer + resultReader ResultReader + multiResultReader MultiResultReader + contextWatcher *ctxwatch.ContextWatcher + + cleanupDone chan struct{} +} + +// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) +// to provide configuration. See documention for ParseConfig for details. ctx can be used to cancel a connect attempt. +func Connect(ctx context.Context, connString string) (*PgConn, error) { + config, err := ParseConfig(connString) + if err != nil { + return nil, err + } + + return ConnectConfig(ctx, config) +} + +// Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with +// ParseConfig. ctx can be used to cancel a connect attempt. +// +// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An +// authentication error will terminate the chain of attempts (like libpq: +// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise, +// if all attempts fail the last error is returned. +func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err error) { + // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from + // zero values. + if !config.createdByParseConfig { + panic("config must be created by ParseConfig") + } + + // ConnectTimeout restricts the whole connection process. + if config.ConnectTimeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, config.ConnectTimeout) + defer cancel() + } + // Simplify usage by treating primary config and fallbacks the same. + fallbackConfigs := []*FallbackConfig{ + { + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + } + fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) + + fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs) + if err != nil { + return nil, &connectError{config: config, msg: "hostname resolving error", err: err} + } + + if len(fallbackConfigs) == 0 { + return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")} + } + + for _, fc := range fallbackConfigs { + pgConn, err = connect(ctx, config, fc) + if err == nil { + break + } else if pgerr, ok := err.(*PgError); ok { + err = &connectError{config: config, msg: "server error", err: pgerr} + ERRCODE_INVALID_PASSWORD := "28P01" // worng password + ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION := "28000" // db does not exist + if pgerr.Code == ERRCODE_INVALID_PASSWORD || pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION { + break + } + } + } + + if err != nil { + return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError + } + + if config.AfterConnect != nil { + err := config.AfterConnect(ctx, pgConn) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "AfterConnect error", err: err} + } + } + + return pgConn, nil +} + +func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) { + var configs []*FallbackConfig + + for _, fb := range fallbacks { + // skip resolve for unix sockets + if strings.HasPrefix(fb.Host, "/") { + configs = append(configs, &FallbackConfig{ + Host: fb.Host, + Port: fb.Port, + TLSConfig: fb.TLSConfig, + }) + + continue + } + + ips, err := lookupFn(ctx, fb.Host) + if err != nil { + return nil, err + } + + for _, ip := range ips { + configs = append(configs, &FallbackConfig{ + Host: ip, + Port: fb.Port, + TLSConfig: fb.TLSConfig, + }) + } + } + + return configs, nil +} + +func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { + pgConn := new(PgConn) + pgConn.config = config + pgConn.wbuf = make([]byte, 0, wbufLen) + pgConn.cleanupDone = make(chan struct{}) + + var err error + network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) + pgConn.conn, err = config.DialFunc(ctx, network, address) + if err != nil { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + err = &errTimeout{err: err} + } + return nil, &connectError{config: config, msg: "dial error", err: err} + } + + pgConn.parameterStatuses = make(map[string]string) + + if fallbackConfig.TLSConfig != nil { + if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "tls error", err: err} + } + } + + pgConn.status = connStatusConnecting + pgConn.contextWatcher = ctxwatch.NewContextWatcher( + func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { pgConn.conn.SetDeadline(time.Time{}) }, + ) + + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + + pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) + + startupMsg := pgproto3.StartupMessage{ + ProtocolVersion: pgproto3.ProtocolVersionNumber, + Parameters: make(map[string]string), + } + + // Copy default run-time params + for k, v := range config.RuntimeParams { + startupMsg.Parameters[k] = v + } + + startupMsg.Parameters["user"] = config.User + if config.Database != "" { + startupMsg.Parameters["database"] = config.Database + } + + if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed to write startup message", err: err} + } + + for { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.conn.Close() + if err, ok := err.(*PgError); ok { + return nil, err + } + return nil, &connectError{config: config, msg: "failed to receive message", err: preferContextOverNetTimeoutError(ctx, err)} + } + + switch msg := msg.(type) { + case *pgproto3.BackendKeyData: + pgConn.pid = msg.ProcessID + pgConn.secretKey = msg.SecretKey + + case *pgproto3.AuthenticationOk: + case *pgproto3.AuthenticationCleartextPassword: + err = pgConn.txPasswordMessage(pgConn.config.Password) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed to write password message", err: err} + } + case *pgproto3.AuthenticationMD5Password: + digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) + err = pgConn.txPasswordMessage(digestedPassword) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed to write password message", err: err} + } + case *pgproto3.AuthenticationSASL: + err = pgConn.scramAuth(msg.AuthMechanisms) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed SASL auth", err: err} + } + + case *pgproto3.ReadyForQuery: + pgConn.status = connStatusIdle + if config.ValidateConnect != nil { + // ValidateConnect may execute commands that cause the context to be watched again. Unwatch first to avoid + // the watch already in progress panic. This is that last thing done by this method so there is no need to + // restart the watch after ValidateConnect returns. + // + // See https://github.com/jackc/pgconn/issues/40. + pgConn.contextWatcher.Unwatch() + + err := config.ValidateConnect(ctx, pgConn) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err} + } + } + return pgConn, nil + case *pgproto3.ParameterStatus: + // handled by ReceiveMessage + case *pgproto3.ErrorResponse: + pgConn.conn.Close() + return nil, ErrorResponseToPgError(msg) + default: + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "received unexpected message", err: err} + } + } +} + +func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { + err = binary.Write(pgConn.conn, binary.BigEndian, []int32{8, 80877103}) + if err != nil { + return + } + + response := make([]byte, 1) + if _, err = io.ReadFull(pgConn.conn, response); err != nil { + return + } + + if response[0] != 'S' { + return errors.New("server refused TLS connection") + } + + pgConn.conn = tls.Client(pgConn.conn, tlsConfig) + + return nil +} + +func (pgConn *PgConn) txPasswordMessage(password string) (err error) { + msg := &pgproto3.PasswordMessage{Password: password} + _, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf)) + return err +} + +func hexMD5(s string) string { + hash := md5.New() + io.WriteString(hash, s) + return hex.EncodeToString(hash.Sum(nil)) +} + +func (pgConn *PgConn) signalMessage() chan struct{} { + if pgConn.bufferingReceive { + panic("BUG: signalMessage when already in progress") + } + + pgConn.bufferingReceive = true + pgConn.bufferingReceiveMux.Lock() + + ch := make(chan struct{}) + go func() { + pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive() + pgConn.bufferingReceiveMux.Unlock() + close(ch) + }() + + return ch +} + +// SendBytes sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as +// error to call SendBytes while reading the result of a query. +// +// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. +// See https://www.postgresql.org/docs/current/protocol.html. +func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { + if err := pgConn.lock(); err != nil { + return err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + n, err := pgConn.conn.Write(buf) + if err != nil { + pgConn.asyncClose() + return &writeError{err: err, safeToRetry: n == 0} + } + + return nil +} + +// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the +// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages +// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger +// the OnNotification callback. +// +// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. +// See https://www.postgresql.org/docs/current/protocol.html. +func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return nil, newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + msg, err := pgConn.receiveMessage() + if err != nil { + err = &pgconnError{ + msg: "receive message failed", + err: preferContextOverNetTimeoutError(ctx, err), + safeToRetry: true} + } + return msg, err +} + +// peekMessage peeks at the next message without setting up context cancellation. +func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { + if pgConn.peekedMsg != nil { + return pgConn.peekedMsg, nil + } + + var msg pgproto3.BackendMessage + var err error + if pgConn.bufferingReceive { + pgConn.bufferingReceiveMux.Lock() + msg = pgConn.bufferingReceiveMsg + err = pgConn.bufferingReceiveErr + pgConn.bufferingReceiveMux.Unlock() + pgConn.bufferingReceive = false + + // If a timeout error happened in the background try the read again. + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + msg, err = pgConn.frontend.Receive() + } + } else { + msg, err = pgConn.frontend.Receive() + } + + if err != nil { + // Close on anything other than timeout error - everything else is fatal + var netErr net.Error + isNetErr := errors.As(err, &netErr) + if !(isNetErr && netErr.Timeout()) { + pgConn.asyncClose() + } + + return nil, err + } + + pgConn.peekedMsg = msg + return msg, nil +} + +// receiveMessage receives a message without setting up context cancellation +func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { + msg, err := pgConn.peekMessage() + if err != nil { + // Close on anything other than timeout error - everything else is fatal + var netErr net.Error + isNetErr := errors.As(err, &netErr) + if !(isNetErr && netErr.Timeout()) { + pgConn.asyncClose() + } + + return nil, err + } + pgConn.peekedMsg = nil + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + pgConn.txStatus = msg.TxStatus + case *pgproto3.ParameterStatus: + pgConn.parameterStatuses[msg.Name] = msg.Value + case *pgproto3.ErrorResponse: + if msg.Severity == "FATAL" { + pgConn.status = connStatusClosed + pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return. + close(pgConn.cleanupDone) + return nil, ErrorResponseToPgError(msg) + } + case *pgproto3.NoticeResponse: + if pgConn.config.OnNotice != nil { + pgConn.config.OnNotice(pgConn, noticeResponseToNotice(msg)) + } + case *pgproto3.NotificationResponse: + if pgConn.config.OnNotification != nil { + pgConn.config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload}) + } + } + + return msg, nil +} + +// Conn returns the underlying net.Conn. +func (pgConn *PgConn) Conn() net.Conn { + return pgConn.conn +} + +// PID returns the backend PID. +func (pgConn *PgConn) PID() uint32 { + return pgConn.pid +} + +// TxStatus returns the current TxStatus as reported by the server in the ReadyForQuery message. +// +// Possible return values: +// 'I' - idle / not in transaction +// 'T' - in a transaction +// 'E' - in a failed transaction +// +// See https://www.postgresql.org/docs/current/protocol-message-formats.html. +func (pgConn *PgConn) TxStatus() byte { + return pgConn.txStatus +} + +// SecretKey returns the backend secret key used to send a cancel query message to the server. +func (pgConn *PgConn) SecretKey() uint32 { + return pgConn.secretKey +} + +// Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by +// sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The +// underlying net.Conn.Close() will always be called regardless of any other errors. +func (pgConn *PgConn) Close(ctx context.Context) error { + if pgConn.status == connStatusClosed { + return nil + } + pgConn.status = connStatusClosed + + defer close(pgConn.cleanupDone) + defer pgConn.conn.Close() + + if ctx != context.Background() { + // Close may be called while a cancellable query is in progress. This will most often be triggered by panic when + // a defer closes the connection (possibly indirectly via a transaction or a connection pool). Unwatch to end any + // previous watch. It is safe to Unwatch regardless of whether a watch is already is progress. + // + // See https://github.com/jackc/pgconn/issues/29 + pgConn.contextWatcher.Unwatch() + + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + // Ignore any errors sending Terminate message and waiting for server to close connection. + // This mimics the behavior of libpq PQfinish. It calls closePGconn which calls sendTerminateConn which purposefully + // ignores errors. + // + // See https://github.com/jackc/pgx/issues/637 + pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) + pgConn.conn.Read(make([]byte, 1)) + + return pgConn.conn.Close() +} + +// asyncClose marks the connection as closed and asynchronously sends a cancel query message and closes the underlying +// connection. +func (pgConn *PgConn) asyncClose() { + if pgConn.status == connStatusClosed { + return + } + pgConn.status = connStatusClosed + + go func() { + defer close(pgConn.cleanupDone) + defer pgConn.conn.Close() + + deadline := time.Now().Add(time.Second * 15) + + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + + pgConn.CancelRequest(ctx) + + pgConn.conn.SetDeadline(deadline) + + pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) + pgConn.conn.Read(make([]byte, 1)) + }() +} + +// CleanupDone returns a channel that will be closed after all underlying resources have been cleaned up. A closed +// connection is no longer usable, but underlying resources, in particular the net.Conn, may not have finished closing +// yet. This is because certain errors such as a context cancellation require that the interrupted function call return +// immediately, but the error may also cause the connection to be closed. In these cases the underlying resources are +// closed asynchronously. +// +// This is only likely to be useful to connection pools. It gives them a way avoid establishing a new connection while +// an old connection is still being cleaned up and thereby exceeding the maximum pool size. +func (pgConn *PgConn) CleanupDone() chan (struct{}) { + return pgConn.cleanupDone +} + +// IsClosed reports if the connection has been closed. +// +// CleanupDone() can be used to determine if all cleanup has been completed. +func (pgConn *PgConn) IsClosed() bool { + return pgConn.status < connStatusIdle +} + +// IsBusy reports if the connection is busy. +func (pgConn *PgConn) IsBusy() bool { + return pgConn.status == connStatusBusy +} + +// lock locks the connection. +func (pgConn *PgConn) lock() error { + switch pgConn.status { + case connStatusBusy: + return &connLockError{status: "conn busy"} // This only should be possible in case of an application bug. + case connStatusClosed: + return &connLockError{status: "conn closed"} + case connStatusUninitialized: + return &connLockError{status: "conn uninitialized"} + } + pgConn.status = connStatusBusy + return nil +} + +func (pgConn *PgConn) unlock() { + switch pgConn.status { + case connStatusBusy: + pgConn.status = connStatusIdle + case connStatusClosed: + default: + panic("BUG: cannot unlock unlocked connection") // This should only be possible if there is a bug in this package. + } +} + +// ParameterStatus returns the value of a parameter reported by the server (e.g. +// server_version). Returns an empty string for unknown parameters. +func (pgConn *PgConn) ParameterStatus(key string) string { + return pgConn.parameterStatuses[key] +} + +// CommandTag is the result of an Exec function +type CommandTag []byte + +// RowsAffected returns the number of rows affected. If the CommandTag was not +// for a row affecting command (e.g. "CREATE TABLE") then it returns 0. +func (ct CommandTag) RowsAffected() int64 { + // Find last non-digit + idx := -1 + for i := len(ct) - 1; i >= 0; i-- { + if ct[i] >= '0' && ct[i] <= '9' { + idx = i + } else { + break + } + } + + if idx == -1 { + return 0 + } + + var n int64 + for _, b := range ct[idx:] { + n = n*10 + int64(b-'0') + } + + return n +} + +func (ct CommandTag) String() string { + return string(ct) +} + +// Insert is true if the command tag starts with "INSERT". +func (ct CommandTag) Insert() bool { + return len(ct) >= 6 && + ct[0] == 'I' && + ct[1] == 'N' && + ct[2] == 'S' && + ct[3] == 'E' && + ct[4] == 'R' && + ct[5] == 'T' +} + +// Update is true if the command tag starts with "UPDATE". +func (ct CommandTag) Update() bool { + return len(ct) >= 6 && + ct[0] == 'U' && + ct[1] == 'P' && + ct[2] == 'D' && + ct[3] == 'A' && + ct[4] == 'T' && + ct[5] == 'E' +} + +// Delete is true if the command tag starts with "DELETE". +func (ct CommandTag) Delete() bool { + return len(ct) >= 6 && + ct[0] == 'D' && + ct[1] == 'E' && + ct[2] == 'L' && + ct[3] == 'E' && + ct[4] == 'T' && + ct[5] == 'E' +} + +// Select is true if the command tag starts with "SELECT". +func (ct CommandTag) Select() bool { + return len(ct) >= 6 && + ct[0] == 'S' && + ct[1] == 'E' && + ct[2] == 'L' && + ct[3] == 'E' && + ct[4] == 'C' && + ct[5] == 'T' +} + +type StatementDescription struct { + Name string + SQL string + ParamOIDs []uint32 + Fields []pgproto3.FieldDescription +} + +// Prepare creates a prepared statement. If the name is empty, the anonymous prepared statement will be used. This +// allows Prepare to also to describe statements without creating a server-side prepared statement. +func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return nil, newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + buf := pgConn.wbuf + buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) + buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) + buf = (&pgproto3.Sync{}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + pgConn.asyncClose() + return nil, &writeError{err: err, safeToRetry: n == 0} + } + + psd := &StatementDescription{Name: name, SQL: sql} + + var parseErr error + +readloop: + for { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.asyncClose() + return nil, preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ParameterDescription: + psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) + copy(psd.ParamOIDs, msg.ParameterOIDs) + case *pgproto3.RowDescription: + psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) + copy(psd.Fields, msg.Fields) + case *pgproto3.ErrorResponse: + parseErr = ErrorResponseToPgError(msg) + case *pgproto3.ReadyForQuery: + break readloop + } + } + + if parseErr != nil { + return nil, parseErr + } + return psd, nil +} + +// ErrorResponseToPgError converts a wire protocol error message to a *PgError. +func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { + return &PgError{ + Severity: msg.Severity, + Code: string(msg.Code), + Message: string(msg.Message), + Detail: string(msg.Detail), + Hint: msg.Hint, + Position: msg.Position, + InternalPosition: msg.InternalPosition, + InternalQuery: string(msg.InternalQuery), + Where: string(msg.Where), + SchemaName: string(msg.SchemaName), + TableName: string(msg.TableName), + ColumnName: string(msg.ColumnName), + DataTypeName: string(msg.DataTypeName), + ConstraintName: msg.ConstraintName, + File: string(msg.File), + Line: msg.Line, + Routine: string(msg.Routine), + } +} + +func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { + pgerr := ErrorResponseToPgError((*pgproto3.ErrorResponse)(msg)) + return (*Notice)(pgerr) +} + +// CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel +// request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there +// is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9 +func (pgConn *PgConn) CancelRequest(ctx context.Context) error { + // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing + // the connection config. This is important in high availability configurations where fallback connections may be + // specified or DNS may be used to load balance. + serverAddr := pgConn.conn.RemoteAddr() + cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) + if err != nil { + return err + } + defer cancelConn.Close() + + if ctx != context.Background() { + contextWatcher := ctxwatch.NewContextWatcher( + func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { cancelConn.SetDeadline(time.Time{}) }, + ) + contextWatcher.Watch(ctx) + defer contextWatcher.Unwatch() + } + + buf := make([]byte, 16) + binary.BigEndian.PutUint32(buf[0:4], 16) + binary.BigEndian.PutUint32(buf[4:8], 80877102) + binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid)) + binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) + _, err = cancelConn.Write(buf) + if err != nil { + return err + } + + _, err = cancelConn.Read(buf) + if err != io.EOF { + return err + } + + return nil +} + +// WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not +// received. +func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { + if err := pgConn.lock(); err != nil { + return err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return newContextAlreadyDoneError(ctx) + default: + } + + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + for { + msg, err := pgConn.receiveMessage() + if err != nil { + return preferContextOverNetTimeoutError(ctx, err) + } + + switch msg.(type) { + case *pgproto3.NotificationResponse: + return nil + } + } +} + +// Exec executes SQL via the PostgreSQL simple query protocol. SQL may contain multiple queries. Execution is +// implicitly wrapped in a transaction unless a transaction is already in progress or SQL contains transaction control +// statements. +// +// Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. +func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { + if err := pgConn.lock(); err != nil { + return &MultiResultReader{ + closed: true, + err: err, + } + } + + pgConn.multiResultReader = MultiResultReader{ + pgConn: pgConn, + ctx: ctx, + } + multiResult := &pgConn.multiResultReader + if ctx != context.Background() { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = newContextAlreadyDoneError(ctx) + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) + } + + buf := pgConn.wbuf + buf = (&pgproto3.Query{String: sql}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + pgConn.asyncClose() + pgConn.contextWatcher.Unwatch() + multiResult.closed = true + multiResult.err = &writeError{err: err, safeToRetry: n == 0} + pgConn.unlock() + return multiResult + } + + return multiResult +} + +// ReceiveResults reads the result that might be returned by Postgres after a SendBytes +// (e.a. after sending a CopyDone in a copy-both situation). +// +// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. +// See https://www.postgresql.org/docs/current/protocol.html. +func (pgConn *PgConn) ReceiveResults(ctx context.Context) *MultiResultReader { + if err := pgConn.lock(); err != nil { + return &MultiResultReader{ + closed: true, + err: err, + } + } + + pgConn.multiResultReader = MultiResultReader{ + pgConn: pgConn, + ctx: ctx, + } + multiResult := &pgConn.multiResultReader + if ctx != context.Background() { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = newContextAlreadyDoneError(ctx) + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) + } + + return multiResult +} + +// ExecParams executes a command via the PostgreSQL extended query protocol. +// +// sql is a SQL command string. It may only contain one query. Parameter substitution is positional using $1, $2, $3, +// etc. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramOIDs is a slice of data type OIDs for paramValues. If paramOIDs is nil, the server will infer the data type for +// all parameters. Any paramOID element that is 0 that will cause the server to infer the data type for that parameter. +// ExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues). +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all params are text format. ExecParams will panic if +// len(paramFormats) is not 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or +// binary format. If resultFormats is nil all results will be in text format. +// +// ResultReader must be closed before PgConn can be used again. +func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader { + result := pgConn.execExtendedPrefix(ctx, paramValues) + if result.closed { + return result + } + + buf := pgConn.wbuf + buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) + buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + + pgConn.execExtendedSuffix(buf, result) + + return result +} + +// ExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all params are text format. ExecPrepared will panic if +// len(paramFormats) is not 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or +// binary format. If resultFormats is nil all results will be in text format. +// +// ResultReader must be closed before PgConn can be used again. +func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader { + result := pgConn.execExtendedPrefix(ctx, paramValues) + if result.closed { + return result + } + + buf := pgConn.wbuf + buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + + pgConn.execExtendedSuffix(buf, result) + + return result +} + +func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { + pgConn.resultReader = ResultReader{ + pgConn: pgConn, + ctx: ctx, + } + result := &pgConn.resultReader + + if err := pgConn.lock(); err != nil { + result.concludeCommand(nil, err) + result.closed = true + return result + } + + if len(paramValues) > math.MaxUint16 { + result.concludeCommand(nil, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) + result.closed = true + pgConn.unlock() + return result + } + + if ctx != context.Background() { + select { + case <-ctx.Done(): + result.concludeCommand(nil, newContextAlreadyDoneError(ctx)) + result.closed = true + pgConn.unlock() + return result + default: + } + pgConn.contextWatcher.Watch(ctx) + } + + return result +} + +func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { + buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) + buf = (&pgproto3.Execute{}).Encode(buf) + buf = (&pgproto3.Sync{}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + pgConn.asyncClose() + result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0}) + pgConn.contextWatcher.Unwatch() + result.closed = true + pgConn.unlock() + return + } + + result.readUntilRowDescription() +} + +// CopyTo executes the copy command sql and copies the results to w. +func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + + if ctx != context.Background() { + select { + case <-ctx.Done(): + pgConn.unlock() + return nil, newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + // Send copy to command + buf := pgConn.wbuf + buf = (&pgproto3.Query{String: sql}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + pgConn.asyncClose() + pgConn.unlock() + return nil, &writeError{err: err, safeToRetry: n == 0} + } + + // Read results + var commandTag CommandTag + var pgErr error + for { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.asyncClose() + return nil, preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.CopyDone: + case *pgproto3.CopyData: + _, err := w.Write(msg.Data) + if err != nil { + pgConn.asyncClose() + return nil, err + } + case *pgproto3.ReadyForQuery: + pgConn.unlock() + return commandTag, pgErr + case *pgproto3.CommandComplete: + commandTag = CommandTag(msg.CommandTag) + case *pgproto3.ErrorResponse: + pgErr = ErrorResponseToPgError(msg) + } + } +} + +// CopyFrom executes the copy command sql and copies all of r to the PostgreSQL server. +// +// Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r +// could still block. +func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return nil, newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + // Send copy to command + buf := pgConn.wbuf + buf = (&pgproto3.Query{String: sql}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + pgConn.asyncClose() + return nil, &writeError{err: err, safeToRetry: n == 0} + } + + // Read until copy in response or error. + var commandTag CommandTag + var pgErr error + pendingCopyInResponse := true + for pendingCopyInResponse { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.asyncClose() + return nil, preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.CopyInResponse: + pendingCopyInResponse = false + case *pgproto3.ErrorResponse: + pgErr = ErrorResponseToPgError(msg) + case *pgproto3.ReadyForQuery: + return commandTag, pgErr + } + } + + // Send copy data + abortCopyChan := make(chan struct{}) + copyErrChan := make(chan error, 1) + signalMessageChan := pgConn.signalMessage() + + go func() { + buf := make([]byte, 0, 65536) + buf = append(buf, 'd') + sp := len(buf) + + for { + n, readErr := r.Read(buf[5:cap(buf)]) + if n > 0 { + buf = buf[0 : n+5] + pgio.SetInt32(buf[sp:], int32(n+4)) + + _, writeErr := pgConn.conn.Write(buf) + if writeErr != nil { + // Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. + pgConn.conn.Close() + + copyErrChan <- writeErr + return + } + } + if readErr != nil { + copyErrChan <- readErr + return + } + + select { + case <-abortCopyChan: + return + default: + } + } + }() + + var copyErr error + for copyErr == nil && pgErr == nil { + select { + case copyErr = <-copyErrChan: + case <-signalMessageChan: + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.asyncClose() + return nil, preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + pgErr = ErrorResponseToPgError(msg) + default: + signalMessageChan = pgConn.signalMessage() + } + } + } + close(abortCopyChan) + + buf = buf[:0] + if copyErr == io.EOF || pgErr != nil { + copyDone := &pgproto3.CopyDone{} + buf = copyDone.Encode(buf) + } else { + copyFail := &pgproto3.CopyFail{Message: copyErr.Error()} + buf = copyFail.Encode(buf) + } + _, err = pgConn.conn.Write(buf) + if err != nil { + pgConn.asyncClose() + return nil, err + } + + // Read results + for { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.asyncClose() + return nil, preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + return commandTag, pgErr + case *pgproto3.CommandComplete: + commandTag = CommandTag(msg.CommandTag) + case *pgproto3.ErrorResponse: + pgErr = ErrorResponseToPgError(msg) + } + } +} + +// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. +type MultiResultReader struct { + pgConn *PgConn + ctx context.Context + + rr *ResultReader + + closed bool + err error +} + +// ReadAll reads all available results. Calling ReadAll is mutually exclusive with all other MultiResultReader methods. +func (mrr *MultiResultReader) ReadAll() ([]*Result, error) { + var results []*Result + + for mrr.NextResult() { + results = append(results, mrr.ResultReader().Read()) + } + err := mrr.Close() + + return results, err +} + +func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) { + msg, err := mrr.pgConn.receiveMessage() + + if err != nil { + mrr.pgConn.contextWatcher.Unwatch() + mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) + mrr.closed = true + mrr.pgConn.asyncClose() + return nil, mrr.err + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + mrr.pgConn.contextWatcher.Unwatch() + mrr.closed = true + mrr.pgConn.unlock() + case *pgproto3.ErrorResponse: + mrr.err = ErrorResponseToPgError(msg) + } + + return msg, nil +} + +// NextResult returns advances the MultiResultReader to the next result and returns true if a result is available. +func (mrr *MultiResultReader) NextResult() bool { + for !mrr.closed && mrr.err == nil { + msg, err := mrr.receiveMessage() + if err != nil { + return false + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + mrr.pgConn.resultReader = ResultReader{ + pgConn: mrr.pgConn, + multiResultReader: mrr, + ctx: mrr.ctx, + fieldDescriptions: msg.Fields, + } + mrr.rr = &mrr.pgConn.resultReader + return true + case *pgproto3.CommandComplete: + mrr.pgConn.resultReader = ResultReader{ + commandTag: CommandTag(msg.CommandTag), + commandConcluded: true, + closed: true, + } + mrr.rr = &mrr.pgConn.resultReader + return true + case *pgproto3.EmptyQueryResponse: + return false + } + } + + return false +} + +// ResultReader returns the current ResultReader. +func (mrr *MultiResultReader) ResultReader() *ResultReader { + return mrr.rr +} + +// Close closes the MultiResultReader and returns the first error that occurred during the MultiResultReader's use. +func (mrr *MultiResultReader) Close() error { + for !mrr.closed { + _, err := mrr.receiveMessage() + if err != nil { + return mrr.err + } + } + + return mrr.err +} + +// ResultReader is a reader for the result of a single query. +type ResultReader struct { + pgConn *PgConn + multiResultReader *MultiResultReader + ctx context.Context + + fieldDescriptions []pgproto3.FieldDescription + rowValues [][]byte + commandTag CommandTag + commandConcluded bool + closed bool + err error +} + +// Result is the saved query response that is returned by calling Read on a ResultReader. +type Result struct { + FieldDescriptions []pgproto3.FieldDescription + Rows [][][]byte + CommandTag CommandTag + Err error +} + +// Read saves the query response to a Result. +func (rr *ResultReader) Read() *Result { + br := &Result{} + + for rr.NextRow() { + if br.FieldDescriptions == nil { + br.FieldDescriptions = make([]pgproto3.FieldDescription, len(rr.FieldDescriptions())) + copy(br.FieldDescriptions, rr.FieldDescriptions()) + } + + row := make([][]byte, len(rr.Values())) + copy(row, rr.Values()) + br.Rows = append(br.Rows, row) + } + + br.CommandTag, br.Err = rr.Close() + + return br +} + +// NextRow advances the ResultReader to the next row and returns true if a row is available. +func (rr *ResultReader) NextRow() bool { + for !rr.commandConcluded { + msg, err := rr.receiveMessage() + if err != nil { + return false + } + + switch msg := msg.(type) { + case *pgproto3.DataRow: + rr.rowValues = msg.Values + return true + } + } + + return false +} + +// FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until +// the ResultReader is closed. +func (rr *ResultReader) FieldDescriptions() []pgproto3.FieldDescription { + return rr.fieldDescriptions +} + +// Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only +// valid until the next NextRow call or the ResultReader is closed. However, the underlying byte data is safe to +// retain a reference to and mutate. +func (rr *ResultReader) Values() [][]byte { + return rr.rowValues +} + +// Close consumes any remaining result data and returns the command tag or +// error. +func (rr *ResultReader) Close() (CommandTag, error) { + if rr.closed { + return rr.commandTag, rr.err + } + rr.closed = true + + for !rr.commandConcluded { + _, err := rr.receiveMessage() + if err != nil { + return nil, rr.err + } + } + + if rr.multiResultReader == nil { + for { + msg, err := rr.receiveMessage() + if err != nil { + return nil, rr.err + } + + switch msg := msg.(type) { + // Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete. + case *pgproto3.ErrorResponse: + rr.err = ErrorResponseToPgError(msg) + case *pgproto3.ReadyForQuery: + rr.pgConn.contextWatcher.Unwatch() + rr.pgConn.unlock() + return rr.commandTag, rr.err + } + } + } + + return rr.commandTag, rr.err +} + +// readUntilRowDescription ensures the ResultReader's fieldDescriptions are loaded. It does not return an error as any +// error will be stored in the ResultReader. +func (rr *ResultReader) readUntilRowDescription() { + for !rr.commandConcluded { + // Peek before receive to avoid consuming a DataRow if the result set does not include a RowDescription method. + // This should never happen under normal pgconn usage, but it is possible if SendBytes and ReceiveResults are + // manually used to construct a query that does not issue a describe statement. + msg, _ := rr.pgConn.peekMessage() + if _, ok := msg.(*pgproto3.DataRow); ok { + return + } + + // Consume the message + msg, _ = rr.receiveMessage() + if _, ok := msg.(*pgproto3.RowDescription); ok { + return + } + } +} + +func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) { + if rr.multiResultReader == nil { + msg, err = rr.pgConn.receiveMessage() + } else { + msg, err = rr.multiResultReader.receiveMessage() + } + + if err != nil { + err = preferContextOverNetTimeoutError(rr.ctx, err) + rr.concludeCommand(nil, err) + rr.pgConn.contextWatcher.Unwatch() + rr.closed = true + if rr.multiResultReader == nil { + rr.pgConn.asyncClose() + } + + return nil, rr.err + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + rr.fieldDescriptions = msg.Fields + case *pgproto3.CommandComplete: + rr.concludeCommand(CommandTag(msg.CommandTag), nil) + case *pgproto3.EmptyQueryResponse: + rr.concludeCommand(nil, nil) + case *pgproto3.ErrorResponse: + rr.concludeCommand(nil, ErrorResponseToPgError(msg)) + } + + return msg, nil +} + +func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { + // Keep the first error that is recorded. Store the error before checking if the command is already concluded to + // allow for receiving an error after CommandComplete but before ReadyForQuery. + if err != nil && rr.err == nil { + rr.err = err + } + + if rr.commandConcluded { + return + } + + rr.commandTag = commandTag + rr.rowValues = nil + rr.commandConcluded = true +} + +// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. +type Batch struct { + buf []byte +} + +// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. +func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { + batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) + batch.ExecPrepared("", paramValues, paramFormats, resultFormats) +} + +// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions. +func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { + batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) + batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) + batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) +} + +// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a +// transaction is already in progress or SQL contains transaction control statements. +func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { + if err := pgConn.lock(); err != nil { + return &MultiResultReader{ + closed: true, + err: err, + } + } + + pgConn.multiResultReader = MultiResultReader{ + pgConn: pgConn, + ctx: ctx, + } + multiResult := &pgConn.multiResultReader + + if ctx != context.Background() { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = newContextAlreadyDoneError(ctx) + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) + } + + batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) + + // A large batch can deadlock without concurrent reading and writing. If the Write fails the underlying net.Conn is + // closed. This is all that can be done without introducing a race condition or adding a concurrent safe communication + // channel to relay the error back. The practical effect of this is that the underlying Write error is not reported. + // The error the code reading the batch results receives will be a closed connection error. + // + // See https://github.com/jackc/pgx/issues/374. + go func() { + _, err := pgConn.conn.Write(batch.buf) + if err != nil { + pgConn.conn.Close() + } + }() + + return multiResult +} + +// EscapeString escapes a string such that it can safely be interpolated into a SQL command string. It does not include +// the surrounding single quotes. +// +// The current implementation requires that standard_conforming_strings=on and client_encoding="UTF8". If these +// conditions are not met an error will be returned. It is possible these restrictions will be lifted in the future. +func (pgConn *PgConn) EscapeString(s string) (string, error) { + if pgConn.ParameterStatus("standard_conforming_strings") != "on" { + return "", errors.New("EscapeString must be run with standard_conforming_strings=on") + } + + if pgConn.ParameterStatus("client_encoding") != "UTF8" { + return "", errors.New("EscapeString must be run with client_encoding=UTF8") + } + + return strings.Replace(s, "'", "''", -1), nil +} + +// HijackedConn is the result of hijacking a connection. +// +// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning +// compatibility. +type HijackedConn struct { + Conn net.Conn // the underlying TCP or unix domain socket connection + PID uint32 // backend pid + SecretKey uint32 // key to use to send a cancel query message to the server + ParameterStatuses map[string]string // parameters that have been reported by the server + TxStatus byte + Frontend Frontend + Config *Config +} + +// Hijack extracts the internal connection data. pgConn must be in an idle state. pgConn is unusable after hijacking. +// Hijacking is typically only useful when using pgconn to establish a connection, but taking complete control of the +// raw connection after that (e.g. a load balancer or proxy). +// +// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning +// compatibility. +func (pgConn *PgConn) Hijack() (*HijackedConn, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + pgConn.status = connStatusClosed + + return &HijackedConn{ + Conn: pgConn.conn, + PID: pgConn.pid, + SecretKey: pgConn.secretKey, + ParameterStatuses: pgConn.parameterStatuses, + TxStatus: pgConn.txStatus, + Frontend: pgConn.frontend, + Config: pgConn.config, + }, nil +} + +// Construct created a PgConn from an already established connection to a PostgreSQL server. This is the inverse of +// PgConn.Hijack. The connection must be in an idle state. +// +// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning +// compatibility. +func Construct(hc *HijackedConn) (*PgConn, error) { + pgConn := &PgConn{ + conn: hc.Conn, + pid: hc.PID, + secretKey: hc.SecretKey, + parameterStatuses: hc.ParameterStatuses, + txStatus: hc.TxStatus, + frontend: hc.Frontend, + config: hc.Config, + + status: connStatusIdle, + + wbuf: make([]byte, 0, wbufLen), + cleanupDone: make(chan struct{}), + } + + pgConn.contextWatcher = ctxwatch.NewContextWatcher( + func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { pgConn.conn.SetDeadline(time.Time{}) }, + ) + + return pgConn, nil +} diff --git a/vendor/github.com/jackc/pgconn/stmtcache/lru.go b/vendor/github.com/jackc/pgconn/stmtcache/lru.go new file mode 100644 index 000000000..f58f2ac34 --- /dev/null +++ b/vendor/github.com/jackc/pgconn/stmtcache/lru.go @@ -0,0 +1,157 @@ +package stmtcache + +import ( + "container/list" + "context" + "fmt" + "sync/atomic" + + "github.com/jackc/pgconn" +) + +var lruCount uint64 + +// LRU implements Cache with a Least Recently Used (LRU) cache. +type LRU struct { + conn *pgconn.PgConn + mode int + cap int + prepareCount int + m map[string]*list.Element + l *list.List + psNamePrefix string + stmtsToClear []string +} + +// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache. +func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU { + mustBeValidMode(mode) + mustBeValidCap(cap) + + n := atomic.AddUint64(&lruCount, 1) + + return &LRU{ + conn: conn, + mode: mode, + cap: cap, + m: make(map[string]*list.Element), + l: list.New(), + psNamePrefix: fmt.Sprintf("lrupsc_%d", n), + } +} + +// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. +func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { + // flush an outstanding bad statements + txStatus := c.conn.TxStatus() + if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 { + for _, stmt := range c.stmtsToClear { + err := c.clearStmt(ctx, stmt) + if err != nil { + return nil, err + } + } + } + + if el, ok := c.m[sql]; ok { + c.l.MoveToFront(el) + return el.Value.(*pgconn.StatementDescription), nil + } + + if c.l.Len() == c.cap { + err := c.removeOldest(ctx) + if err != nil { + return nil, err + } + } + + psd, err := c.prepare(ctx, sql) + if err != nil { + return nil, err + } + + el := c.l.PushFront(psd) + c.m[sql] = el + + return psd, nil +} + +// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. +func (c *LRU) Clear(ctx context.Context) error { + for c.l.Len() > 0 { + err := c.removeOldest(ctx) + if err != nil { + return err + } + } + + return nil +} + +func (c *LRU) StatementErrored(sql string, err error) { + pgErr, ok := err.(*pgconn.PgError) + if !ok { + return + } + + isInvalidCachedPlanError := pgErr.Severity == "ERROR" && + pgErr.Code == "0A000" && + pgErr.Message == "cached plan must not change result type" + if isInvalidCachedPlanError { + c.stmtsToClear = append(c.stmtsToClear, sql) + } +} + +func (c *LRU) clearStmt(ctx context.Context, sql string) error { + elem, inMap := c.m[sql] + if !inMap { + // The statement probably fell off the back of the list. In that case, we've + // ensured that it isn't in the cache, so we can declare victory. + return nil + } + + c.l.Remove(elem) + + psd := elem.Value.(*pgconn.StatementDescription) + delete(c.m, psd.SQL) + if c.mode == ModePrepare { + return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close() + } + return nil +} + +// Len returns the number of cached prepared statement descriptions. +func (c *LRU) Len() int { + return c.l.Len() +} + +// Cap returns the maximum number of cached prepared statement descriptions. +func (c *LRU) Cap() int { + return c.cap +} + +// Mode returns the mode of the cache (ModePrepare or ModeDescribe) +func (c *LRU) Mode() int { + return c.mode +} + +func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { + var name string + if c.mode == ModePrepare { + name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount) + c.prepareCount += 1 + } + + return c.conn.Prepare(ctx, name, sql, nil) +} + +func (c *LRU) removeOldest(ctx context.Context) error { + oldest := c.l.Back() + c.l.Remove(oldest) + psd := oldest.Value.(*pgconn.StatementDescription) + delete(c.m, psd.SQL) + if c.mode == ModePrepare { + return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close() + } + return nil +} diff --git a/vendor/github.com/jackc/pgconn/stmtcache/stmtcache.go b/vendor/github.com/jackc/pgconn/stmtcache/stmtcache.go new file mode 100644 index 000000000..d083e1b4f --- /dev/null +++ b/vendor/github.com/jackc/pgconn/stmtcache/stmtcache.go @@ -0,0 +1,58 @@ +// Package stmtcache is a cache that can be used to implement lazy prepared statements. +package stmtcache + +import ( + "context" + + "github.com/jackc/pgconn" +) + +const ( + ModePrepare = iota // Cache should prepare named statements. + ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement. +) + +// Cache prepares and caches prepared statement descriptions. +type Cache interface { + // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. + Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) + + // Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. + Clear(ctx context.Context) error + + // StatementErrored informs the cache that the given statement resulted in an error when it + // was last used against the database. In some cases, this will cause the cache to maer that + // statement as bad. The bad statement will instead be flushed during the next call to Get + // that occurs outside of a failed transaction. + StatementErrored(sql string, err error) + + // Len returns the number of cached prepared statement descriptions. + Len() int + + // Cap returns the maximum number of cached prepared statement descriptions. + Cap() int + + // Mode returns the mode of the cache (ModePrepare or ModeDescribe) + Mode() int +} + +// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is +// the maximum size of the cache. +func New(conn *pgconn.PgConn, mode int, cap int) Cache { + mustBeValidMode(mode) + mustBeValidCap(cap) + + return NewLRU(conn, mode, cap) +} + +func mustBeValidMode(mode int) { + if mode != ModePrepare && mode != ModeDescribe { + panic("mode must be ModePrepare or ModeDescribe") + } +} + +func mustBeValidCap(cap int) { + if cap < 1 { + panic("cache must have cap of >= 1") + } +} diff --git a/vendor/github.com/jackc/pgio/.travis.yml b/vendor/github.com/jackc/pgio/.travis.yml new file mode 100644 index 000000000..e176228e8 --- /dev/null +++ b/vendor/github.com/jackc/pgio/.travis.yml @@ -0,0 +1,9 @@ +language: go + +go: + - 1.x + - tip + +matrix: + allow_failures: + - go: tip diff --git a/vendor/github.com/jackc/pgio/LICENSE b/vendor/github.com/jackc/pgio/LICENSE new file mode 100644 index 000000000..c1c4f50fc --- /dev/null +++ b/vendor/github.com/jackc/pgio/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2019 Jack Christensen + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/jackc/pgio/README.md b/vendor/github.com/jackc/pgio/README.md new file mode 100644 index 000000000..1952ed862 --- /dev/null +++ b/vendor/github.com/jackc/pgio/README.md @@ -0,0 +1,11 @@ +[](https://godoc.org/github.com/jackc/pgio) +[](https://travis-ci.org/jackc/pgio) + +# pgio + +Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol. + +pgio provides functions for appending integers to a []byte while doing byte +order conversion. + +Extracted from original implementation in https://github.com/jackc/pgx. diff --git a/vendor/github.com/jackc/pgio/doc.go b/vendor/github.com/jackc/pgio/doc.go new file mode 100644 index 000000000..ef2dcc7f7 --- /dev/null +++ b/vendor/github.com/jackc/pgio/doc.go @@ -0,0 +1,6 @@ +// Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol. +/* +pgio provides functions for appending integers to a []byte while doing byte +order conversion. +*/ +package pgio diff --git a/vendor/github.com/jackc/pgio/go.mod b/vendor/github.com/jackc/pgio/go.mod new file mode 100644 index 000000000..c1efdddb6 --- /dev/null +++ b/vendor/github.com/jackc/pgio/go.mod @@ -0,0 +1,3 @@ +module github.com/jackc/pgio + +go 1.12 diff --git a/vendor/github.com/jackc/pgio/write.go b/vendor/github.com/jackc/pgio/write.go new file mode 100644 index 000000000..96aedf9dd --- /dev/null +++ b/vendor/github.com/jackc/pgio/write.go @@ -0,0 +1,40 @@ +package pgio + +import "encoding/binary" + +func AppendUint16(buf []byte, n uint16) []byte { + wp := len(buf) + buf = append(buf, 0, 0) + binary.BigEndian.PutUint16(buf[wp:], n) + return buf +} + +func AppendUint32(buf []byte, n uint32) []byte { + wp := len(buf) + buf = append(buf, 0, 0, 0, 0) + binary.BigEndian.PutUint32(buf[wp:], n) + return buf +} + +func AppendUint64(buf []byte, n uint64) []byte { + wp := len(buf) + buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0) + binary.BigEndian.PutUint64(buf[wp:], n) + return buf +} + +func AppendInt16(buf []byte, n int16) []byte { + return AppendUint16(buf, uint16(n)) +} + +func AppendInt32(buf []byte, n int32) []byte { + return AppendUint32(buf, uint32(n)) +} + +func AppendInt64(buf []byte, n int64) []byte { + return AppendUint64(buf, uint64(n)) +} + +func SetInt32(buf []byte, n int32) { + binary.BigEndian.PutUint32(buf, uint32(n)) +} diff --git a/vendor/github.com/jackc/pgpassfile/.travis.yml b/vendor/github.com/jackc/pgpassfile/.travis.yml new file mode 100644 index 000000000..e176228e8 --- /dev/null +++ b/vendor/github.com/jackc/pgpassfile/.travis.yml @@ -0,0 +1,9 @@ +language: go + +go: + - 1.x + - tip + +matrix: + allow_failures: + - go: tip diff --git a/vendor/github.com/jackc/pgpassfile/LICENSE b/vendor/github.com/jackc/pgpassfile/LICENSE new file mode 100644 index 000000000..c1c4f50fc --- /dev/null +++ b/vendor/github.com/jackc/pgpassfile/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2019 Jack Christensen + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/jackc/pgpassfile/README.md b/vendor/github.com/jackc/pgpassfile/README.md new file mode 100644 index 000000000..661289ed8 --- /dev/null +++ b/vendor/github.com/jackc/pgpassfile/README.md @@ -0,0 +1,8 @@ +[](https://godoc.org/github.com/jackc/pgpassfile) +[](https://travis-ci.org/jackc/pgpassfile) + +# pgpassfile + +Package pgpassfile is a parser PostgreSQL .pgpass files. + +Extracted and rewritten from original implementation in https://github.com/jackc/pgx. diff --git a/vendor/github.com/jackc/pgpassfile/go.mod b/vendor/github.com/jackc/pgpassfile/go.mod new file mode 100644 index 000000000..48d90e313 --- /dev/null +++ b/vendor/github.com/jackc/pgpassfile/go.mod @@ -0,0 +1,5 @@ +module github.com/jackc/pgpassfile + +go 1.12 + +require github.com/stretchr/testify v1.3.0 diff --git a/vendor/github.com/jackc/pgpassfile/go.sum b/vendor/github.com/jackc/pgpassfile/go.sum new file mode 100644 index 000000000..4347755af --- /dev/null +++ b/vendor/github.com/jackc/pgpassfile/go.sum @@ -0,0 +1,7 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= diff --git a/vendor/github.com/jackc/pgpassfile/pgpass.go b/vendor/github.com/jackc/pgpassfile/pgpass.go new file mode 100644 index 000000000..f7eed3c84 --- /dev/null +++ b/vendor/github.com/jackc/pgpassfile/pgpass.go @@ -0,0 +1,110 @@ +// Package pgpassfile is a parser PostgreSQL .pgpass files. +package pgpassfile + +import ( + "bufio" + "io" + "os" + "regexp" + "strings" +) + +// Entry represents a line in a PG passfile. +type Entry struct { + Hostname string + Port string + Database string + Username string + Password string +} + +// Passfile is the in memory data structure representing a PG passfile. +type Passfile struct { + Entries []*Entry +} + +// ReadPassfile reads the file at path and parses it into a Passfile. +func ReadPassfile(path string) (*Passfile, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + return ParsePassfile(f) +} + +// ParsePassfile reads r and parses it into a Passfile. +func ParsePassfile(r io.Reader) (*Passfile, error) { + passfile := &Passfile{} + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + entry := parseLine(scanner.Text()) + if entry != nil { + passfile.Entries = append(passfile.Entries, entry) + } + } + + return passfile, scanner.Err() +} + +// Match (not colons or escaped colon or escaped backslash)+. Essentially gives a split on unescaped +// colon. +var colonSplitterRegexp = regexp.MustCompile("(([^:]|(\\:)))+") + +// var colonSplitterRegexp = regexp.MustCompile("((?:[^:]|(?:\\:)|(?:\\\\))+)") + +// parseLine parses a line into an *Entry. It returns nil on comment lines or any other unparsable +// line. +func parseLine(line string) *Entry { + const ( + tmpBackslash = "\r" + tmpColon = "\n" + ) + + line = strings.TrimSpace(line) + + if strings.HasPrefix(line, "#") { + return nil + } + + line = strings.Replace(line, `\\`, tmpBackslash, -1) + line = strings.Replace(line, `\:`, tmpColon, -1) + + parts := strings.Split(line, ":") + if len(parts) != 5 { + return nil + } + + // Unescape escaped colons and backslashes + for i := range parts { + parts[i] = strings.Replace(parts[i], tmpBackslash, `\`, -1) + parts[i] = strings.Replace(parts[i], tmpColon, `:`, -1) + } + + return &Entry{ + Hostname: parts[0], + Port: parts[1], + Database: parts[2], + Username: parts[3], + Password: parts[4], + } +} + +// FindPassword finds the password for the provided hostname, port, database, and username. For a +// Unix domain socket hostname must be set to "localhost". An empty string will be returned if no +// match is found. +// +// See https://www.postgresql.org/docs/current/libpq-pgpass.html for more password file information. +func (pf *Passfile) FindPassword(hostname, port, database, username string) (password string) { + for _, e := range pf.Entries { + if (e.Hostname == "*" || e.Hostname == hostname) && + (e.Port == "*" || e.Port == port) && + (e.Database == "*" || e.Database == database) && + (e.Username == "*" || e.Username == username) { + return e.Password + } + } + return "" +} diff --git a/vendor/github.com/jackc/pgproto3/v2/.travis.yml b/vendor/github.com/jackc/pgproto3/v2/.travis.yml new file mode 100644 index 000000000..e176228e8 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/.travis.yml @@ -0,0 +1,9 @@ +language: go + +go: + - 1.x + - tip + +matrix: + allow_failures: + - go: tip diff --git a/vendor/github.com/jackc/pgproto3/v2/LICENSE b/vendor/github.com/jackc/pgproto3/v2/LICENSE new file mode 100644 index 000000000..c1c4f50fc --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2019 Jack Christensen + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/jackc/pgproto3/v2/README.md b/vendor/github.com/jackc/pgproto3/v2/README.md new file mode 100644 index 000000000..565b3efd5 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/README.md @@ -0,0 +1,12 @@ +[](https://godoc.org/github.com/jackc/pgproto3) +[](https://travis-ci.org/jackc/pgproto3) + +# pgproto3 + +Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3. + +pgproto3 can be used as a foundation for PostgreSQL drivers, proxies, mock servers, load balancers and more. + +See example/pgfortune for a playful example of a fake PostgreSQL server. + +Extracted from original implementation in https://github.com/jackc/pgx. diff --git a/vendor/github.com/jackc/pgproto3/v2/authentication_cleartext_password.go b/vendor/github.com/jackc/pgproto3/v2/authentication_cleartext_password.go new file mode 100644 index 000000000..241fa6005 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/authentication_cleartext_password.go @@ -0,0 +1,52 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgio" +) + +// AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required. +type AuthenticationCleartextPassword struct { +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationCleartextPassword) Backend() {} + +// Backend identifies this message as an authentication response. +func (*AuthenticationCleartextPassword) AuthenticationResponse() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationCleartextPassword) Decode(src []byte) error { + if len(src) != 4 { + return errors.New("bad authentication message size") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeCleartextPassword { + return errors.New("bad auth type") + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte { + dst = append(dst, 'R') + dst = pgio.AppendInt32(dst, 8) + dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationCleartextPassword) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "AuthenticationCleartextPassword", + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/authentication_md5_password.go b/vendor/github.com/jackc/pgproto3/v2/authentication_md5_password.go new file mode 100644 index 000000000..32ec0390e --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/authentication_md5_password.go @@ -0,0 +1,77 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgio" +) + +// AuthenticationMD5Password is a message sent from the backend indicating that an MD5 hashed password is required. +type AuthenticationMD5Password struct { + Salt [4]byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationMD5Password) Backend() {} + +// Backend identifies this message as an authentication response. +func (*AuthenticationMD5Password) AuthenticationResponse() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationMD5Password) Decode(src []byte) error { + if len(src) != 8 { + return errors.New("bad authentication message size") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeMD5Password { + return errors.New("bad auth type") + } + + copy(dst.Salt[:], src[4:8]) + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationMD5Password) Encode(dst []byte) []byte { + dst = append(dst, 'R') + dst = pgio.AppendInt32(dst, 12) + dst = pgio.AppendUint32(dst, AuthTypeMD5Password) + dst = append(dst, src.Salt[:]...) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationMD5Password) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Salt [4]byte + }{ + Type: "AuthenticationMD5Password", + Salt: src.Salt, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + Salt [4]byte + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Salt = msg.Salt + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/authentication_ok.go b/vendor/github.com/jackc/pgproto3/v2/authentication_ok.go new file mode 100644 index 000000000..2b476fe51 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/authentication_ok.go @@ -0,0 +1,52 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgio" +) + +// AuthenticationOk is a message sent from the backend indicating that authentication was successful. +type AuthenticationOk struct { +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationOk) Backend() {} + +// Backend identifies this message as an authentication response. +func (*AuthenticationOk) AuthenticationResponse() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationOk) Decode(src []byte) error { + if len(src) != 4 { + return errors.New("bad authentication message size") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeOk { + return errors.New("bad auth type") + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationOk) Encode(dst []byte) []byte { + dst = append(dst, 'R') + dst = pgio.AppendInt32(dst, 8) + dst = pgio.AppendUint32(dst, AuthTypeOk) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationOk) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "AuthenticationOK", + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/authentication_sasl.go b/vendor/github.com/jackc/pgproto3/v2/authentication_sasl.go new file mode 100644 index 000000000..bdcb2c367 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/authentication_sasl.go @@ -0,0 +1,75 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgio" +) + +// AuthenticationSASL is a message sent from the backend indicating that SASL authentication is required. +type AuthenticationSASL struct { + AuthMechanisms []string +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationSASL) Backend() {} + +// Backend identifies this message as an authentication response. +func (*AuthenticationSASL) AuthenticationResponse() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationSASL) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeSASL { + return errors.New("bad auth type") + } + + authMechanisms := src[4:] + for len(authMechanisms) > 1 { + idx := bytes.IndexByte(authMechanisms, 0) + if idx > 0 { + dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx])) + authMechanisms = authMechanisms[idx+1:] + } + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationSASL) Encode(dst []byte) []byte { + dst = append(dst, 'R') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint32(dst, AuthTypeSASL) + + for _, s := range src.AuthMechanisms { + dst = append(dst, []byte(s)...) + dst = append(dst, 0) + } + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationSASL) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + AuthMechanisms []string + }{ + Type: "AuthenticationSASL", + AuthMechanisms: src.AuthMechanisms, + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/authentication_sasl_continue.go b/vendor/github.com/jackc/pgproto3/v2/authentication_sasl_continue.go new file mode 100644 index 000000000..7f4a9c235 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/authentication_sasl_continue.go @@ -0,0 +1,81 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgio" +) + +// AuthenticationSASLContinue is a message sent from the backend containing a SASL challenge. +type AuthenticationSASLContinue struct { + Data []byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationSASLContinue) Backend() {} + +// Backend identifies this message as an authentication response. +func (*AuthenticationSASLContinue) AuthenticationResponse() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationSASLContinue) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeSASLContinue { + return errors.New("bad auth type") + } + + dst.Data = src[4:] + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte { + dst = append(dst, 'R') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint32(dst, AuthTypeSASLContinue) + + dst = append(dst, src.Data...) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src AuthenticationSASLContinue) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "AuthenticationSASLContinue", + Data: string(src.Data), + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *AuthenticationSASLContinue) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Data = []byte(msg.Data) + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/authentication_sasl_final.go b/vendor/github.com/jackc/pgproto3/v2/authentication_sasl_final.go new file mode 100644 index 000000000..d82b9ee4d --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/authentication_sasl_final.go @@ -0,0 +1,81 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgio" +) + +// AuthenticationSASLFinal is a message sent from the backend indicating a SASL authentication has completed. +type AuthenticationSASLFinal struct { + Data []byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationSASLFinal) Backend() {} + +// Backend identifies this message as an authentication response. +func (*AuthenticationSASLFinal) AuthenticationResponse() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationSASLFinal) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeSASLFinal { + return errors.New("bad auth type") + } + + dst.Data = src[4:] + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte { + dst = append(dst, 'R') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint32(dst, AuthTypeSASLFinal) + + dst = append(dst, src.Data...) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Unmarshaler. +func (src AuthenticationSASLFinal) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "AuthenticationSASLFinal", + Data: string(src.Data), + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *AuthenticationSASLFinal) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Data = []byte(msg.Data) + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/backend.go b/vendor/github.com/jackc/pgproto3/v2/backend.go new file mode 100644 index 000000000..e9ba38fc3 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/backend.go @@ -0,0 +1,204 @@ +package pgproto3 + +import ( + "encoding/binary" + "fmt" + "io" +) + +// Backend acts as a server for the PostgreSQL wire protocol version 3. +type Backend struct { + cr ChunkReader + w io.Writer + + // Frontend message flyweights + bind Bind + cancelRequest CancelRequest + _close Close + copyFail CopyFail + copyData CopyData + copyDone CopyDone + describe Describe + execute Execute + flush Flush + gssEncRequest GSSEncRequest + parse Parse + query Query + sslRequest SSLRequest + startupMessage StartupMessage + sync Sync + terminate Terminate + + bodyLen int + msgType byte + partialMsg bool + authType uint32 +} + +const ( + minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code. + maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source. +) + +// NewBackend creates a new Backend. +func NewBackend(cr ChunkReader, w io.Writer) *Backend { + return &Backend{cr: cr, w: w} +} + +// Send sends a message to the frontend. +func (b *Backend) Send(msg BackendMessage) error { + _, err := b.w.Write(msg.Encode(nil)) + return err +} + +// ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method +// because the initial connection message is "special" and does not include the message type as the first byte. This +// will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest. +func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { + buf, err := b.cr.Next(4) + if err != nil { + return nil, err + } + msgSize := int(binary.BigEndian.Uint32(buf) - 4) + + if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen { + return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize) + } + + buf, err = b.cr.Next(msgSize) + if err != nil { + return nil, translateEOFtoErrUnexpectedEOF(err) + } + + code := binary.BigEndian.Uint32(buf) + + switch code { + case ProtocolVersionNumber: + err = b.startupMessage.Decode(buf) + if err != nil { + return nil, err + } + return &b.startupMessage, nil + case sslRequestNumber: + err = b.sslRequest.Decode(buf) + if err != nil { + return nil, err + } + return &b.sslRequest, nil + case cancelRequestCode: + err = b.cancelRequest.Decode(buf) + if err != nil { + return nil, err + } + return &b.cancelRequest, nil + case gssEncReqNumber: + err = b.gssEncRequest.Decode(buf) + if err != nil { + return nil, err + } + return &b.gssEncRequest, nil + default: + return nil, fmt.Errorf("unknown startup message code: %d", code) + } +} + +// Receive receives a message from the frontend. The returned message is only valid until the next call to Receive. +func (b *Backend) Receive() (FrontendMessage, error) { + if !b.partialMsg { + header, err := b.cr.Next(5) + if err != nil { + return nil, translateEOFtoErrUnexpectedEOF(err) + } + + b.msgType = header[0] + b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 + b.partialMsg = true + } + + var msg FrontendMessage + switch b.msgType { + case 'B': + msg = &b.bind + case 'C': + msg = &b._close + case 'D': + msg = &b.describe + case 'E': + msg = &b.execute + case 'f': + msg = &b.copyFail + case 'd': + msg = &b.copyData + case 'c': + msg = &b.copyDone + case 'H': + msg = &b.flush + case 'P': + msg = &b.parse + case 'p': + switch b.authType { + case AuthTypeSASL: + msg = &SASLInitialResponse{} + case AuthTypeSASLContinue: + msg = &SASLResponse{} + case AuthTypeSASLFinal: + msg = &SASLResponse{} + case AuthTypeCleartextPassword, AuthTypeMD5Password: + fallthrough + default: + // to maintain backwards compatability + msg = &PasswordMessage{} + } + case 'Q': + msg = &b.query + case 'S': + msg = &b.sync + case 'X': + msg = &b.terminate + default: + return nil, fmt.Errorf("unknown message type: %c", b.msgType) + } + + msgBody, err := b.cr.Next(b.bodyLen) + if err != nil { + return nil, translateEOFtoErrUnexpectedEOF(err) + } + + b.partialMsg = false + + err = msg.Decode(msgBody) + return msg, err +} + +// SetAuthType sets the authentication type in the backend. +// Since multiple message types can start with 'p', SetAuthType allows +// contextual identification of FrontendMessages. For example, in the +// PG message flow documentation for PasswordMessage: +// +// Byte1('p') +// +// Identifies the message as a password response. Note that this is also used for +// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from +// the context. +// +// Since the Frontend does not know about the state of a backend, it is important +// to call SetAuthType() after an authentication request is received by the Frontend. +func (b *Backend) SetAuthType(authType uint32) error { + switch authType { + case AuthTypeOk, + AuthTypeCleartextPassword, + AuthTypeMD5Password, + AuthTypeSCMCreds, + AuthTypeGSS, + AuthTypeGSSCont, + AuthTypeSSPI, + AuthTypeSASL, + AuthTypeSASLContinue, + AuthTypeSASLFinal: + b.authType = authType + default: + return fmt.Errorf("authType not recognized: %d", authType) + } + + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/backend_key_data.go b/vendor/github.com/jackc/pgproto3/v2/backend_key_data.go new file mode 100644 index 000000000..ca20dd259 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/backend_key_data.go @@ -0,0 +1,51 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgio" +) + +type BackendKeyData struct { + ProcessID uint32 + SecretKey uint32 +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*BackendKeyData) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *BackendKeyData) Decode(src []byte) error { + if len(src) != 8 { + return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)} + } + + dst.ProcessID = binary.BigEndian.Uint32(src[:4]) + dst.SecretKey = binary.BigEndian.Uint32(src[4:]) + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *BackendKeyData) Encode(dst []byte) []byte { + dst = append(dst, 'K') + dst = pgio.AppendUint32(dst, 12) + dst = pgio.AppendUint32(dst, src.ProcessID) + dst = pgio.AppendUint32(dst, src.SecretKey) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src BackendKeyData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProcessID uint32 + SecretKey uint32 + }{ + Type: "BackendKeyData", + ProcessID: src.ProcessID, + SecretKey: src.SecretKey, + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/big_endian.go b/vendor/github.com/jackc/pgproto3/v2/big_endian.go new file mode 100644 index 000000000..f7bdb97eb --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/big_endian.go @@ -0,0 +1,37 @@ +package pgproto3 + +import ( + "encoding/binary" +) + +type BigEndianBuf [8]byte + +func (b BigEndianBuf) Int16(n int16) []byte { + buf := b[0:2] + binary.BigEndian.PutUint16(buf, uint16(n)) + return buf +} + +func (b BigEndianBuf) Uint16(n uint16) []byte { + buf := b[0:2] + binary.BigEndian.PutUint16(buf, n) + return buf +} + +func (b BigEndianBuf) Int32(n int32) []byte { + buf := b[0:4] + binary.BigEndian.PutUint32(buf, uint32(n)) + return buf +} + +func (b BigEndianBuf) Uint32(n uint32) []byte { + buf := b[0:4] + binary.BigEndian.PutUint32(buf, n) + return buf +} + +func (b BigEndianBuf) Int64(n int64) []byte { + buf := b[0:8] + binary.BigEndian.PutUint64(buf, uint64(n)) + return buf +} diff --git a/vendor/github.com/jackc/pgproto3/v2/bind.go b/vendor/github.com/jackc/pgproto3/v2/bind.go new file mode 100644 index 000000000..e9664f59f --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/bind.go @@ -0,0 +1,216 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "encoding/json" + "fmt" + + "github.com/jackc/pgio" +) + +type Bind struct { + DestinationPortal string + PreparedStatement string + ParameterFormatCodes []int16 + Parameters [][]byte + ResultFormatCodes []int16 +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*Bind) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *Bind) Decode(src []byte) error { + *dst = Bind{} + + idx := bytes.IndexByte(src, 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + dst.DestinationPortal = string(src[:idx]) + rp := idx + 1 + + idx = bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + dst.PreparedStatement = string(src[rp : rp+idx]) + rp += idx + 1 + + if len(src[rp:]) < 2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + if parameterFormatCodeCount > 0 { + dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount) + + if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + for i := 0; i < parameterFormatCodeCount; i++ { + dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } + } + + if len(src[rp:]) < 2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + parameterCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + if parameterCount > 0 { + dst.Parameters = make([][]byte, parameterCount) + + for i := 0; i < parameterCount; i++ { + if len(src[rp:]) < 4 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + + msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + // null + if msgSize == -1 { + continue + } + + if len(src[rp:]) < msgSize { + return &invalidMessageFormatErr{messageType: "Bind"} + } + + dst.Parameters[i] = src[rp : rp+msgSize] + rp += msgSize + } + } + + if len(src[rp:]) < 2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + dst.ResultFormatCodes = make([]int16, resultFormatCodeCount) + if len(src[rp:]) < len(dst.ResultFormatCodes)*2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + for i := 0; i < resultFormatCodeCount; i++ { + dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Bind) Encode(dst []byte) []byte { + dst = append(dst, 'B') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.DestinationPortal...) + dst = append(dst, 0) + dst = append(dst, src.PreparedStatement...) + dst = append(dst, 0) + + dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes))) + for _, fc := range src.ParameterFormatCodes { + dst = pgio.AppendInt16(dst, fc) + } + + dst = pgio.AppendUint16(dst, uint16(len(src.Parameters))) + for _, p := range src.Parameters { + if p == nil { + dst = pgio.AppendInt32(dst, -1) + continue + } + + dst = pgio.AppendInt32(dst, int32(len(p))) + dst = append(dst, p...) + } + + dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes))) + for _, fc := range src.ResultFormatCodes { + dst = pgio.AppendInt16(dst, fc) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src Bind) MarshalJSON() ([]byte, error) { + formattedParameters := make([]map[string]string, len(src.Parameters)) + for i, p := range src.Parameters { + if p == nil { + continue + } + + textFormat := true + if len(src.ParameterFormatCodes) == 1 { + textFormat = src.ParameterFormatCodes[0] == 0 + } else if len(src.ParameterFormatCodes) > 1 { + textFormat = src.ParameterFormatCodes[i] == 0 + } + + if textFormat { + formattedParameters[i] = map[string]string{"text": string(p)} + } else { + formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)} + } + } + + return json.Marshal(struct { + Type string + DestinationPortal string + PreparedStatement string + ParameterFormatCodes []int16 + Parameters []map[string]string + ResultFormatCodes []int16 + }{ + Type: "Bind", + DestinationPortal: src.DestinationPortal, + PreparedStatement: src.PreparedStatement, + ParameterFormatCodes: src.ParameterFormatCodes, + Parameters: formattedParameters, + ResultFormatCodes: src.ResultFormatCodes, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *Bind) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + DestinationPortal string + PreparedStatement string + ParameterFormatCodes []int16 + Parameters []map[string]string + ResultFormatCodes []int16 + } + err := json.Unmarshal(data, &msg) + if err != nil { + return err + } + dst.DestinationPortal = msg.DestinationPortal + dst.PreparedStatement = msg.PreparedStatement + dst.ParameterFormatCodes = msg.ParameterFormatCodes + dst.Parameters = make([][]byte, len(msg.Parameters)) + dst.ResultFormatCodes = msg.ResultFormatCodes + for n, parameter := range msg.Parameters { + dst.Parameters[n], err = getValueFromJSON(parameter) + if err != nil { + return fmt.Errorf("cannot get param %d: %w", n, err) + } + } + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/bind_complete.go b/vendor/github.com/jackc/pgproto3/v2/bind_complete.go new file mode 100644 index 000000000..3be256c89 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/bind_complete.go @@ -0,0 +1,34 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type BindComplete struct{} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*BindComplete) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *BindComplete) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *BindComplete) Encode(dst []byte) []byte { + return append(dst, '2', 0, 0, 0, 4) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src BindComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "BindComplete", + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/cancel_request.go b/vendor/github.com/jackc/pgproto3/v2/cancel_request.go new file mode 100644 index 000000000..942e404be --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/cancel_request.go @@ -0,0 +1,58 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgio" +) + +const cancelRequestCode = 80877102 + +type CancelRequest struct { + ProcessID uint32 + SecretKey uint32 +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*CancelRequest) Frontend() {} + +func (dst *CancelRequest) Decode(src []byte) error { + if len(src) != 12 { + return errors.New("bad cancel request size") + } + + requestCode := binary.BigEndian.Uint32(src) + + if requestCode != cancelRequestCode { + return errors.New("bad cancel request code") + } + + dst.ProcessID = binary.BigEndian.Uint32(src[4:]) + dst.SecretKey = binary.BigEndian.Uint32(src[8:]) + + return nil +} + +// Encode encodes src into dst. dst will include the 4 byte message length. +func (src *CancelRequest) Encode(dst []byte) []byte { + dst = pgio.AppendInt32(dst, 16) + dst = pgio.AppendInt32(dst, cancelRequestCode) + dst = pgio.AppendUint32(dst, src.ProcessID) + dst = pgio.AppendUint32(dst, src.SecretKey) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CancelRequest) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProcessID uint32 + SecretKey uint32 + }{ + Type: "CancelRequest", + ProcessID: src.ProcessID, + SecretKey: src.SecretKey, + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/chunkreader.go b/vendor/github.com/jackc/pgproto3/v2/chunkreader.go new file mode 100644 index 000000000..92206f358 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/chunkreader.go @@ -0,0 +1,19 @@ +package pgproto3 + +import ( + "io" + + "github.com/jackc/chunkreader/v2" +) + +// ChunkReader is an interface to decouple github.com/jackc/chunkreader from this package. +type ChunkReader interface { + // Next returns buf filled with the next n bytes. If an error (including a partial read) occurs, + // buf must be nil. Next must preserve any partially read data. Next must not reuse buf. + Next(n int) (buf []byte, err error) +} + +// NewChunkReader creates and returns a new default ChunkReader. +func NewChunkReader(r io.Reader) ChunkReader { + return chunkreader.New(r) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/close.go b/vendor/github.com/jackc/pgproto3/v2/close.go new file mode 100644 index 000000000..a45f2b930 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/close.go @@ -0,0 +1,89 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" + "errors" + + "github.com/jackc/pgio" +) + +type Close struct { + ObjectType byte // 'S' = prepared statement, 'P' = portal + Name string +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*Close) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *Close) Decode(src []byte) error { + if len(src) < 2 { + return &invalidMessageFormatErr{messageType: "Close"} + } + + dst.ObjectType = src[0] + rp := 1 + + idx := bytes.IndexByte(src[rp:], 0) + if idx != len(src[rp:])-1 { + return &invalidMessageFormatErr{messageType: "Close"} + } + + dst.Name = string(src[rp : len(src)-1]) + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Close) Encode(dst []byte) []byte { + dst = append(dst, 'C') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.ObjectType) + dst = append(dst, src.Name...) + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src Close) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ObjectType string + Name string + }{ + Type: "Close", + ObjectType: string(src.ObjectType), + Name: src.Name, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *Close) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + ObjectType string + Name string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.ObjectType) != 1 { + return errors.New("invalid length for Close.ObjectType") + } + + dst.ObjectType = byte(msg.ObjectType[0]) + dst.Name = msg.Name + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/close_complete.go b/vendor/github.com/jackc/pgproto3/v2/close_complete.go new file mode 100644 index 000000000..1d7b8f085 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/close_complete.go @@ -0,0 +1,34 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type CloseComplete struct{} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*CloseComplete) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *CloseComplete) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CloseComplete) Encode(dst []byte) []byte { + return append(dst, '3', 0, 0, 0, 4) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CloseComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "CloseComplete", + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/command_complete.go b/vendor/github.com/jackc/pgproto3/v2/command_complete.go new file mode 100644 index 000000000..cdc49f39f --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/command_complete.go @@ -0,0 +1,71 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" + + "github.com/jackc/pgio" +) + +type CommandComplete struct { + CommandTag []byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*CommandComplete) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *CommandComplete) Decode(src []byte) error { + idx := bytes.IndexByte(src, 0) + if idx != len(src)-1 { + return &invalidMessageFormatErr{messageType: "CommandComplete"} + } + + dst.CommandTag = src[:idx] + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CommandComplete) Encode(dst []byte) []byte { + dst = append(dst, 'C') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.CommandTag...) + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CommandComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + CommandTag string + }{ + Type: "CommandComplete", + CommandTag: string(src.CommandTag), + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CommandComplete) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + CommandTag string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.CommandTag = []byte(msg.CommandTag) + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/copy_both_response.go b/vendor/github.com/jackc/pgproto3/v2/copy_both_response.go new file mode 100644 index 000000000..fbd985d86 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/copy_both_response.go @@ -0,0 +1,95 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgio" +) + +type CopyBothResponse struct { + OverallFormat byte + ColumnFormatCodes []uint16 +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*CopyBothResponse) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *CopyBothResponse) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 3 { + return &invalidMessageFormatErr{messageType: "CopyBothResponse"} + } + + overallFormat := buf.Next(1)[0] + + columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) + if buf.Len() != columnCount*2 { + return &invalidMessageFormatErr{messageType: "CopyBothResponse"} + } + + columnFormatCodes := make([]uint16, columnCount) + for i := 0; i < columnCount; i++ { + columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) + } + + *dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CopyBothResponse) Encode(dst []byte) []byte { + dst = append(dst, 'W') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) + for _, fc := range src.ColumnFormatCodes { + dst = pgio.AppendUint16(dst, fc) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CopyBothResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ColumnFormatCodes []uint16 + }{ + Type: "CopyBothResponse", + ColumnFormatCodes: src.ColumnFormatCodes, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyBothResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + OverallFormat string + ColumnFormatCodes []uint16 + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.OverallFormat) != 1 { + return errors.New("invalid length for CopyBothResponse.OverallFormat") + } + + dst.OverallFormat = msg.OverallFormat[0] + dst.ColumnFormatCodes = msg.ColumnFormatCodes + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/copy_data.go b/vendor/github.com/jackc/pgproto3/v2/copy_data.go new file mode 100644 index 000000000..128aa198c --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/copy_data.go @@ -0,0 +1,62 @@ +package pgproto3 + +import ( + "encoding/hex" + "encoding/json" + + "github.com/jackc/pgio" +) + +type CopyData struct { + Data []byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*CopyData) Backend() {} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*CopyData) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *CopyData) Decode(src []byte) error { + dst.Data = src + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CopyData) Encode(dst []byte) []byte { + dst = append(dst, 'd') + dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) + dst = append(dst, src.Data...) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CopyData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "CopyData", + Data: hex.EncodeToString(src.Data), + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyData) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Data = []byte(msg.Data) + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/copy_done.go b/vendor/github.com/jackc/pgproto3/v2/copy_done.go new file mode 100644 index 000000000..0e13282bf --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/copy_done.go @@ -0,0 +1,38 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type CopyDone struct { +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*CopyDone) Backend() {} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*CopyDone) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *CopyDone) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "CopyDone", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CopyDone) Encode(dst []byte) []byte { + return append(dst, 'c', 0, 0, 0, 4) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CopyDone) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "CopyDone", + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/copy_fail.go b/vendor/github.com/jackc/pgproto3/v2/copy_fail.go new file mode 100644 index 000000000..78ff0b30b --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/copy_fail.go @@ -0,0 +1,53 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" + + "github.com/jackc/pgio" +) + +type CopyFail struct { + Message string +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*CopyFail) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *CopyFail) Decode(src []byte) error { + idx := bytes.IndexByte(src, 0) + if idx != len(src)-1 { + return &invalidMessageFormatErr{messageType: "CopyFail"} + } + + dst.Message = string(src[:idx]) + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CopyFail) Encode(dst []byte) []byte { + dst = append(dst, 'f') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.Message...) + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CopyFail) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Message string + }{ + Type: "CopyFail", + Message: src.Message, + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/copy_in_response.go b/vendor/github.com/jackc/pgproto3/v2/copy_in_response.go new file mode 100644 index 000000000..80733adcf --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/copy_in_response.go @@ -0,0 +1,96 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgio" +) + +type CopyInResponse struct { + OverallFormat byte + ColumnFormatCodes []uint16 +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*CopyInResponse) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *CopyInResponse) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 3 { + return &invalidMessageFormatErr{messageType: "CopyInResponse"} + } + + overallFormat := buf.Next(1)[0] + + columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) + if buf.Len() != columnCount*2 { + return &invalidMessageFormatErr{messageType: "CopyInResponse"} + } + + columnFormatCodes := make([]uint16, columnCount) + for i := 0; i < columnCount; i++ { + columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) + } + + *dst = CopyInResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CopyInResponse) Encode(dst []byte) []byte { + dst = append(dst, 'G') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.OverallFormat) + dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) + for _, fc := range src.ColumnFormatCodes { + dst = pgio.AppendUint16(dst, fc) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CopyInResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ColumnFormatCodes []uint16 + }{ + Type: "CopyInResponse", + ColumnFormatCodes: src.ColumnFormatCodes, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyInResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + OverallFormat string + ColumnFormatCodes []uint16 + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.OverallFormat) != 1 { + return errors.New("invalid length for CopyInResponse.OverallFormat") + } + + dst.OverallFormat = msg.OverallFormat[0] + dst.ColumnFormatCodes = msg.ColumnFormatCodes + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/copy_out_response.go b/vendor/github.com/jackc/pgproto3/v2/copy_out_response.go new file mode 100644 index 000000000..5e607e3ac --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/copy_out_response.go @@ -0,0 +1,96 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgio" +) + +type CopyOutResponse struct { + OverallFormat byte + ColumnFormatCodes []uint16 +} + +func (*CopyOutResponse) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *CopyOutResponse) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 3 { + return &invalidMessageFormatErr{messageType: "CopyOutResponse"} + } + + overallFormat := buf.Next(1)[0] + + columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) + if buf.Len() != columnCount*2 { + return &invalidMessageFormatErr{messageType: "CopyOutResponse"} + } + + columnFormatCodes := make([]uint16, columnCount) + for i := 0; i < columnCount; i++ { + columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) + } + + *dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *CopyOutResponse) Encode(dst []byte) []byte { + dst = append(dst, 'H') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.OverallFormat) + + dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) + for _, fc := range src.ColumnFormatCodes { + dst = pgio.AppendUint16(dst, fc) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CopyOutResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ColumnFormatCodes []uint16 + }{ + Type: "CopyOutResponse", + ColumnFormatCodes: src.ColumnFormatCodes, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *CopyOutResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + OverallFormat string + ColumnFormatCodes []uint16 + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + if len(msg.OverallFormat) != 1 { + return errors.New("invalid length for CopyOutResponse.OverallFormat") + } + + dst.OverallFormat = msg.OverallFormat[0] + dst.ColumnFormatCodes = msg.ColumnFormatCodes + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/data_row.go b/vendor/github.com/jackc/pgproto3/v2/data_row.go new file mode 100644 index 000000000..637687616 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/data_row.go @@ -0,0 +1,142 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/hex" + "encoding/json" + + "github.com/jackc/pgio" +) + +type DataRow struct { + Values [][]byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*DataRow) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *DataRow) Decode(src []byte) error { + if len(src) < 2 { + return &invalidMessageFormatErr{messageType: "DataRow"} + } + rp := 0 + fieldCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + // If the capacity of the values slice is too small OR substantially too + // large reallocate. This is too avoid one row with many columns from + // permanently allocating memory. + if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 { + newCap := 32 + if newCap < fieldCount { + newCap = fieldCount + } + dst.Values = make([][]byte, fieldCount, newCap) + } else { + dst.Values = dst.Values[:fieldCount] + } + + for i := 0; i < fieldCount; i++ { + if len(src[rp:]) < 4 { + return &invalidMessageFormatErr{messageType: "DataRow"} + } + + msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + // null + if msgSize == -1 { + dst.Values[i] = nil + } else { + if len(src[rp:]) < msgSize { + return &invalidMessageFormatErr{messageType: "DataRow"} + } + + dst.Values[i] = src[rp : rp+msgSize : rp+msgSize] + rp += msgSize + } + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *DataRow) Encode(dst []byte) []byte { + dst = append(dst, 'D') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = pgio.AppendUint16(dst, uint16(len(src.Values))) + for _, v := range src.Values { + if v == nil { + dst = pgio.AppendInt32(dst, -1) + continue + } + + dst = pgio.AppendInt32(dst, int32(len(v))) + dst = append(dst, v...) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src DataRow) MarshalJSON() ([]byte, error) { + formattedValues := make([]map[string]string, len(src.Values)) + for i, v := range src.Values { + if v == nil { + continue + } + + var hasNonPrintable bool + for _, b := range v { + if b < 32 { + hasNonPrintable = true + break + } + } + + if hasNonPrintable { + formattedValues[i] = map[string]string{"binary": hex.EncodeToString(v)} + } else { + formattedValues[i] = map[string]string{"text": string(v)} + } + } + + return json.Marshal(struct { + Type string + Values []map[string]string + }{ + Type: "DataRow", + Values: formattedValues, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *DataRow) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Values []map[string]string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Values = make([][]byte, len(msg.Values)) + for n, parameter := range msg.Values { + var err error + dst.Values[n], err = getValueFromJSON(parameter) + if err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/describe.go b/vendor/github.com/jackc/pgproto3/v2/describe.go new file mode 100644 index 000000000..0d825db19 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/describe.go @@ -0,0 +1,88 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" + "errors" + + "github.com/jackc/pgio" +) + +type Describe struct { + ObjectType byte // 'S' = prepared statement, 'P' = portal + Name string +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*Describe) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *Describe) Decode(src []byte) error { + if len(src) < 2 { + return &invalidMessageFormatErr{messageType: "Describe"} + } + + dst.ObjectType = src[0] + rp := 1 + + idx := bytes.IndexByte(src[rp:], 0) + if idx != len(src[rp:])-1 { + return &invalidMessageFormatErr{messageType: "Describe"} + } + + dst.Name = string(src[rp : len(src)-1]) + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Describe) Encode(dst []byte) []byte { + dst = append(dst, 'D') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.ObjectType) + dst = append(dst, src.Name...) + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src Describe) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ObjectType string + Name string + }{ + Type: "Describe", + ObjectType: string(src.ObjectType), + Name: src.Name, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *Describe) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + ObjectType string + Name string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + if len(msg.ObjectType) != 1 { + return errors.New("invalid length for Describe.ObjectType") + } + + dst.ObjectType = byte(msg.ObjectType[0]) + dst.Name = msg.Name + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/doc.go b/vendor/github.com/jackc/pgproto3/v2/doc.go new file mode 100644 index 000000000..8226dc983 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/doc.go @@ -0,0 +1,4 @@ +// Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3. +// +// See https://www.postgresql.org/docs/current/protocol-message-formats.html for meanings of the different messages. +package pgproto3 diff --git a/vendor/github.com/jackc/pgproto3/v2/empty_query_response.go b/vendor/github.com/jackc/pgproto3/v2/empty_query_response.go new file mode 100644 index 000000000..2b85e744b --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/empty_query_response.go @@ -0,0 +1,34 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type EmptyQueryResponse struct{} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*EmptyQueryResponse) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *EmptyQueryResponse) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *EmptyQueryResponse) Encode(dst []byte) []byte { + return append(dst, 'I', 0, 0, 0, 4) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src EmptyQueryResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "EmptyQueryResponse", + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/error_response.go b/vendor/github.com/jackc/pgproto3/v2/error_response.go new file mode 100644 index 000000000..ec51e0192 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/error_response.go @@ -0,0 +1,334 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "strconv" +) + +type ErrorResponse struct { + Severity string + SeverityUnlocalized string // only in 9.6 and greater + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string + + UnknownFields map[byte]string +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*ErrorResponse) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *ErrorResponse) Decode(src []byte) error { + *dst = ErrorResponse{} + + buf := bytes.NewBuffer(src) + + for { + k, err := buf.ReadByte() + if err != nil { + return err + } + if k == 0 { + break + } + + vb, err := buf.ReadBytes(0) + if err != nil { + return err + } + v := string(vb[:len(vb)-1]) + + switch k { + case 'S': + dst.Severity = v + case 'V': + dst.SeverityUnlocalized = v + case 'C': + dst.Code = v + case 'M': + dst.Message = v + case 'D': + dst.Detail = v + case 'H': + dst.Hint = v + case 'P': + s := v + n, _ := strconv.ParseInt(s, 10, 32) + dst.Position = int32(n) + case 'p': + s := v + n, _ := strconv.ParseInt(s, 10, 32) + dst.InternalPosition = int32(n) + case 'q': + dst.InternalQuery = v + case 'W': + dst.Where = v + case 's': + dst.SchemaName = v + case 't': + dst.TableName = v + case 'c': + dst.ColumnName = v + case 'd': + dst.DataTypeName = v + case 'n': + dst.ConstraintName = v + case 'F': + dst.File = v + case 'L': + s := v + n, _ := strconv.ParseInt(s, 10, 32) + dst.Line = int32(n) + case 'R': + dst.Routine = v + + default: + if dst.UnknownFields == nil { + dst.UnknownFields = make(map[byte]string) + } + dst.UnknownFields[k] = v + } + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *ErrorResponse) Encode(dst []byte) []byte { + return append(dst, src.marshalBinary('E')...) +} + +func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte(typeByte) + buf.Write(bigEndian.Uint32(0)) + + if src.Severity != "" { + buf.WriteByte('S') + buf.WriteString(src.Severity) + buf.WriteByte(0) + } + if src.SeverityUnlocalized != "" { + buf.WriteByte('V') + buf.WriteString(src.SeverityUnlocalized) + buf.WriteByte(0) + } + if src.Code != "" { + buf.WriteByte('C') + buf.WriteString(src.Code) + buf.WriteByte(0) + } + if src.Message != "" { + buf.WriteByte('M') + buf.WriteString(src.Message) + buf.WriteByte(0) + } + if src.Detail != "" { + buf.WriteByte('D') + buf.WriteString(src.Detail) + buf.WriteByte(0) + } + if src.Hint != "" { + buf.WriteByte('H') + buf.WriteString(src.Hint) + buf.WriteByte(0) + } + if src.Position != 0 { + buf.WriteByte('P') + buf.WriteString(strconv.Itoa(int(src.Position))) + buf.WriteByte(0) + } + if src.InternalPosition != 0 { + buf.WriteByte('p') + buf.WriteString(strconv.Itoa(int(src.InternalPosition))) + buf.WriteByte(0) + } + if src.InternalQuery != "" { + buf.WriteByte('q') + buf.WriteString(src.InternalQuery) + buf.WriteByte(0) + } + if src.Where != "" { + buf.WriteByte('W') + buf.WriteString(src.Where) + buf.WriteByte(0) + } + if src.SchemaName != "" { + buf.WriteByte('s') + buf.WriteString(src.SchemaName) + buf.WriteByte(0) + } + if src.TableName != "" { + buf.WriteByte('t') + buf.WriteString(src.TableName) + buf.WriteByte(0) + } + if src.ColumnName != "" { + buf.WriteByte('c') + buf.WriteString(src.ColumnName) + buf.WriteByte(0) + } + if src.DataTypeName != "" { + buf.WriteByte('d') + buf.WriteString(src.DataTypeName) + buf.WriteByte(0) + } + if src.ConstraintName != "" { + buf.WriteByte('n') + buf.WriteString(src.ConstraintName) + buf.WriteByte(0) + } + if src.File != "" { + buf.WriteByte('F') + buf.WriteString(src.File) + buf.WriteByte(0) + } + if src.Line != 0 { + buf.WriteByte('L') + buf.WriteString(strconv.Itoa(int(src.Line))) + buf.WriteByte(0) + } + if src.Routine != "" { + buf.WriteByte('R') + buf.WriteString(src.Routine) + buf.WriteByte(0) + } + + for k, v := range src.UnknownFields { + buf.WriteByte(k) + buf.WriteByte(0) + buf.WriteString(v) + buf.WriteByte(0) + } + + buf.WriteByte(0) + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes() +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src ErrorResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Severity string + SeverityUnlocalized string // only in 9.6 and greater + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string + + UnknownFields map[byte]string + }{ + Type: "ErrorResponse", + Severity: src.Severity, + SeverityUnlocalized: src.SeverityUnlocalized, + Code: src.Code, + Message: src.Message, + Detail: src.Detail, + Hint: src.Hint, + Position: src.Position, + InternalPosition: src.InternalPosition, + InternalQuery: src.InternalQuery, + Where: src.Where, + SchemaName: src.SchemaName, + TableName: src.TableName, + ColumnName: src.ColumnName, + DataTypeName: src.DataTypeName, + ConstraintName: src.ConstraintName, + File: src.File, + Line: src.Line, + Routine: src.Routine, + UnknownFields: src.UnknownFields, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *ErrorResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + Severity string + SeverityUnlocalized string // only in 9.6 and greater + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string + + UnknownFields map[byte]string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + dst.Severity = msg.Severity + dst.SeverityUnlocalized = msg.SeverityUnlocalized + dst.Code = msg.Code + dst.Message = msg.Message + dst.Detail = msg.Detail + dst.Hint = msg.Hint + dst.Position = msg.Position + dst.InternalPosition = msg.InternalPosition + dst.InternalQuery = msg.InternalQuery + dst.Where = msg.Where + dst.SchemaName = msg.SchemaName + dst.TableName = msg.TableName + dst.ColumnName = msg.ColumnName + dst.DataTypeName = msg.DataTypeName + dst.ConstraintName = msg.ConstraintName + dst.File = msg.File + dst.Line = msg.Line + dst.Routine = msg.Routine + + dst.UnknownFields = msg.UnknownFields + + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/execute.go b/vendor/github.com/jackc/pgproto3/v2/execute.go new file mode 100644 index 000000000..8bae61332 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/execute.go @@ -0,0 +1,65 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgio" +) + +type Execute struct { + Portal string + MaxRows uint32 +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*Execute) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *Execute) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.Portal = string(b[:len(b)-1]) + + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "Execute"} + } + dst.MaxRows = binary.BigEndian.Uint32(buf.Next(4)) + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Execute) Encode(dst []byte) []byte { + dst = append(dst, 'E') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.Portal...) + dst = append(dst, 0) + + dst = pgio.AppendUint32(dst, src.MaxRows) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src Execute) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Portal string + MaxRows uint32 + }{ + Type: "Execute", + Portal: src.Portal, + MaxRows: src.MaxRows, + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/flush.go b/vendor/github.com/jackc/pgproto3/v2/flush.go new file mode 100644 index 000000000..2725f6894 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/flush.go @@ -0,0 +1,34 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type Flush struct{} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*Flush) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *Flush) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "Flush", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Flush) Encode(dst []byte) []byte { + return append(dst, 'H', 0, 0, 0, 4) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src Flush) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "Flush", + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/frontend.go b/vendor/github.com/jackc/pgproto3/v2/frontend.go new file mode 100644 index 000000000..c33dfb084 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/frontend.go @@ -0,0 +1,201 @@ +package pgproto3 + +import ( + "encoding/binary" + "errors" + "fmt" + "io" +) + +// Frontend acts as a client for the PostgreSQL wire protocol version 3. +type Frontend struct { + cr ChunkReader + w io.Writer + + // Backend message flyweights + authenticationOk AuthenticationOk + authenticationCleartextPassword AuthenticationCleartextPassword + authenticationMD5Password AuthenticationMD5Password + authenticationSASL AuthenticationSASL + authenticationSASLContinue AuthenticationSASLContinue + authenticationSASLFinal AuthenticationSASLFinal + backendKeyData BackendKeyData + bindComplete BindComplete + closeComplete CloseComplete + commandComplete CommandComplete + copyBothResponse CopyBothResponse + copyData CopyData + copyInResponse CopyInResponse + copyOutResponse CopyOutResponse + copyDone CopyDone + dataRow DataRow + emptyQueryResponse EmptyQueryResponse + errorResponse ErrorResponse + functionCallResponse FunctionCallResponse + noData NoData + noticeResponse NoticeResponse + notificationResponse NotificationResponse + parameterDescription ParameterDescription + parameterStatus ParameterStatus + parseComplete ParseComplete + readyForQuery ReadyForQuery + rowDescription RowDescription + portalSuspended PortalSuspended + + bodyLen int + msgType byte + partialMsg bool + authType uint32 +} + +// NewFrontend creates a new Frontend. +func NewFrontend(cr ChunkReader, w io.Writer) *Frontend { + return &Frontend{cr: cr, w: w} +} + +// Send sends a message to the backend. +func (f *Frontend) Send(msg FrontendMessage) error { + _, err := f.w.Write(msg.Encode(nil)) + return err +} + +func translateEOFtoErrUnexpectedEOF(err error) error { + if err == io.EOF { + return io.ErrUnexpectedEOF + } + return err +} + +// Receive receives a message from the backend. The returned message is only valid until the next call to Receive. +func (f *Frontend) Receive() (BackendMessage, error) { + if !f.partialMsg { + header, err := f.cr.Next(5) + if err != nil { + return nil, translateEOFtoErrUnexpectedEOF(err) + } + + f.msgType = header[0] + f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 + f.partialMsg = true + } + + msgBody, err := f.cr.Next(f.bodyLen) + if err != nil { + return nil, translateEOFtoErrUnexpectedEOF(err) + } + + f.partialMsg = false + + var msg BackendMessage + switch f.msgType { + case '1': + msg = &f.parseComplete + case '2': + msg = &f.bindComplete + case '3': + msg = &f.closeComplete + case 'A': + msg = &f.notificationResponse + case 'c': + msg = &f.copyDone + case 'C': + msg = &f.commandComplete + case 'd': + msg = &f.copyData + case 'D': + msg = &f.dataRow + case 'E': + msg = &f.errorResponse + case 'G': + msg = &f.copyInResponse + case 'H': + msg = &f.copyOutResponse + case 'I': + msg = &f.emptyQueryResponse + case 'K': + msg = &f.backendKeyData + case 'n': + msg = &f.noData + case 'N': + msg = &f.noticeResponse + case 'R': + var err error + msg, err = f.findAuthenticationMessageType(msgBody) + if err != nil { + return nil, err + } + case 's': + msg = &f.portalSuspended + case 'S': + msg = &f.parameterStatus + case 't': + msg = &f.parameterDescription + case 'T': + msg = &f.rowDescription + case 'V': + msg = &f.functionCallResponse + case 'W': + msg = &f.copyBothResponse + case 'Z': + msg = &f.readyForQuery + default: + return nil, fmt.Errorf("unknown message type: %c", f.msgType) + } + + err = msg.Decode(msgBody) + return msg, err +} + +// Authentication message type constants. +// See src/include/libpq/pqcomm.h for all +// constants. +const ( + AuthTypeOk = 0 + AuthTypeCleartextPassword = 3 + AuthTypeMD5Password = 5 + AuthTypeSCMCreds = 6 + AuthTypeGSS = 7 + AuthTypeGSSCont = 8 + AuthTypeSSPI = 9 + AuthTypeSASL = 10 + AuthTypeSASLContinue = 11 + AuthTypeSASLFinal = 12 +) + +func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) { + if len(src) < 4 { + return nil, errors.New("authentication message too short") + } + f.authType = binary.BigEndian.Uint32(src[:4]) + + switch f.authType { + case AuthTypeOk: + return &f.authenticationOk, nil + case AuthTypeCleartextPassword: + return &f.authenticationCleartextPassword, nil + case AuthTypeMD5Password: + return &f.authenticationMD5Password, nil + case AuthTypeSCMCreds: + return nil, errors.New("AuthTypeSCMCreds is unimplemented") + case AuthTypeGSS: + return nil, errors.New("AuthTypeGSS is unimplemented") + case AuthTypeGSSCont: + return nil, errors.New("AuthTypeGSSCont is unimplemented") + case AuthTypeSSPI: + return nil, errors.New("AuthTypeSSPI is unimplemented") + case AuthTypeSASL: + return &f.authenticationSASL, nil + case AuthTypeSASLContinue: + return &f.authenticationSASLContinue, nil + case AuthTypeSASLFinal: + return &f.authenticationSASLFinal, nil + default: + return nil, fmt.Errorf("unknown authentication type: %d", f.authType) + } +} + +// GetAuthType returns the authType used in the current state of the frontend. +// See SetAuthType for more information. +func (f *Frontend) GetAuthType() uint32 { + return f.authType +} diff --git a/vendor/github.com/jackc/pgproto3/v2/function_call_response.go b/vendor/github.com/jackc/pgproto3/v2/function_call_response.go new file mode 100644 index 000000000..53d642221 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/function_call_response.go @@ -0,0 +1,101 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/hex" + "encoding/json" + + "github.com/jackc/pgio" +) + +type FunctionCallResponse struct { + Result []byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*FunctionCallResponse) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *FunctionCallResponse) Decode(src []byte) error { + if len(src) < 4 { + return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} + } + rp := 0 + resultSize := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + if resultSize == -1 { + dst.Result = nil + return nil + } + + if len(src[rp:]) != resultSize { + return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} + } + + dst.Result = src[rp:] + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *FunctionCallResponse) Encode(dst []byte) []byte { + dst = append(dst, 'V') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + if src.Result == nil { + dst = pgio.AppendInt32(dst, -1) + } else { + dst = pgio.AppendInt32(dst, int32(len(src.Result))) + dst = append(dst, src.Result...) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src FunctionCallResponse) MarshalJSON() ([]byte, error) { + var formattedValue map[string]string + var hasNonPrintable bool + for _, b := range src.Result { + if b < 32 { + hasNonPrintable = true + break + } + } + + if hasNonPrintable { + formattedValue = map[string]string{"binary": hex.EncodeToString(src.Result)} + } else { + formattedValue = map[string]string{"text": string(src.Result)} + } + + return json.Marshal(struct { + Type string + Result map[string]string + }{ + Type: "FunctionCallResponse", + Result: formattedValue, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *FunctionCallResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Result map[string]string + } + err := json.Unmarshal(data, &msg) + if err != nil { + return err + } + dst.Result, err = getValueFromJSON(msg.Result) + return err +} diff --git a/vendor/github.com/jackc/pgproto3/v2/go.mod b/vendor/github.com/jackc/pgproto3/v2/go.mod new file mode 100644 index 000000000..36041a94a --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/go.mod @@ -0,0 +1,9 @@ +module github.com/jackc/pgproto3/v2 + +go 1.12 + +require ( + github.com/jackc/chunkreader/v2 v2.0.0 + github.com/jackc/pgio v1.0.0 + github.com/stretchr/testify v1.4.0 +) diff --git a/vendor/github.com/jackc/pgproto3/v2/go.sum b/vendor/github.com/jackc/pgproto3/v2/go.sum new file mode 100644 index 000000000..dd9cd044f --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/go.sum @@ -0,0 +1,14 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= +github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/vendor/github.com/jackc/pgproto3/v2/gss_enc_request.go b/vendor/github.com/jackc/pgproto3/v2/gss_enc_request.go new file mode 100644 index 000000000..cf405a3e0 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/gss_enc_request.go @@ -0,0 +1,49 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgio" +) + +const gssEncReqNumber = 80877104 + +type GSSEncRequest struct { +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*GSSEncRequest) Frontend() {} + +func (dst *GSSEncRequest) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("gss encoding request too short") + } + + requestCode := binary.BigEndian.Uint32(src) + + if requestCode != gssEncReqNumber { + return errors.New("bad gss encoding request code") + } + + return nil +} + +// Encode encodes src into dst. dst will include the 4 byte message length. +func (src *GSSEncRequest) Encode(dst []byte) []byte { + dst = pgio.AppendInt32(dst, 8) + dst = pgio.AppendInt32(dst, gssEncReqNumber) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src GSSEncRequest) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProtocolVersion uint32 + Parameters map[string]string + }{ + Type: "GSSEncRequest", + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/no_data.go b/vendor/github.com/jackc/pgproto3/v2/no_data.go new file mode 100644 index 000000000..d8f85d38a --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/no_data.go @@ -0,0 +1,34 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type NoData struct{} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*NoData) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *NoData) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *NoData) Encode(dst []byte) []byte { + return append(dst, 'n', 0, 0, 0, 4) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src NoData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "NoData", + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/notice_response.go b/vendor/github.com/jackc/pgproto3/v2/notice_response.go new file mode 100644 index 000000000..4ac28a791 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/notice_response.go @@ -0,0 +1,17 @@ +package pgproto3 + +type NoticeResponse ErrorResponse + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*NoticeResponse) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *NoticeResponse) Decode(src []byte) error { + return (*ErrorResponse)(dst).Decode(src) +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *NoticeResponse) Encode(dst []byte) []byte { + return append(dst, (*ErrorResponse)(src).marshalBinary('N')...) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/notification_response.go b/vendor/github.com/jackc/pgproto3/v2/notification_response.go new file mode 100644 index 000000000..e762eb967 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/notification_response.go @@ -0,0 +1,73 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgio" +) + +type NotificationResponse struct { + PID uint32 + Channel string + Payload string +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*NotificationResponse) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *NotificationResponse) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + pid := binary.BigEndian.Uint32(buf.Next(4)) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + channel := string(b[:len(b)-1]) + + b, err = buf.ReadBytes(0) + if err != nil { + return err + } + payload := string(b[:len(b)-1]) + + *dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload} + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *NotificationResponse) Encode(dst []byte) []byte { + dst = append(dst, 'A') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = pgio.AppendUint32(dst, src.PID) + dst = append(dst, src.Channel...) + dst = append(dst, 0) + dst = append(dst, src.Payload...) + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src NotificationResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + PID uint32 + Channel string + Payload string + }{ + Type: "NotificationResponse", + PID: src.PID, + Channel: src.Channel, + Payload: src.Payload, + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/parameter_description.go b/vendor/github.com/jackc/pgproto3/v2/parameter_description.go new file mode 100644 index 000000000..e28965c8a --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/parameter_description.go @@ -0,0 +1,66 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgio" +) + +type ParameterDescription struct { + ParameterOIDs []uint32 +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*ParameterDescription) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *ParameterDescription) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "ParameterDescription"} + } + + // Reported parameter count will be incorrect when number of args is greater than uint16 + buf.Next(2) + // Instead infer parameter count by remaining size of message + parameterCount := buf.Len() / 4 + + *dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)} + + for i := 0; i < parameterCount; i++ { + dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4)) + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *ParameterDescription) Encode(dst []byte) []byte { + dst = append(dst, 't') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) + for _, oid := range src.ParameterOIDs { + dst = pgio.AppendUint32(dst, oid) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src ParameterDescription) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ParameterOIDs []uint32 + }{ + Type: "ParameterDescription", + ParameterOIDs: src.ParameterOIDs, + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/parameter_status.go b/vendor/github.com/jackc/pgproto3/v2/parameter_status.go new file mode 100644 index 000000000..c4021d92f --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/parameter_status.go @@ -0,0 +1,66 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" + + "github.com/jackc/pgio" +) + +type ParameterStatus struct { + Name string + Value string +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*ParameterStatus) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *ParameterStatus) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + name := string(b[:len(b)-1]) + + b, err = buf.ReadBytes(0) + if err != nil { + return err + } + value := string(b[:len(b)-1]) + + *dst = ParameterStatus{Name: name, Value: value} + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *ParameterStatus) Encode(dst []byte) []byte { + dst = append(dst, 'S') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.Name...) + dst = append(dst, 0) + dst = append(dst, src.Value...) + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (ps ParameterStatus) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Name string + Value string + }{ + Type: "ParameterStatus", + Name: ps.Name, + Value: ps.Value, + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/parse.go b/vendor/github.com/jackc/pgproto3/v2/parse.go new file mode 100644 index 000000000..723885d41 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/parse.go @@ -0,0 +1,88 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgio" +) + +type Parse struct { + Name string + Query string + ParameterOIDs []uint32 +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*Parse) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *Parse) Decode(src []byte) error { + *dst = Parse{} + + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.Name = string(b[:len(b)-1]) + + b, err = buf.ReadBytes(0) + if err != nil { + return err + } + dst.Query = string(b[:len(b)-1]) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "Parse"} + } + parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2))) + + for i := 0; i < parameterOIDCount; i++ { + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "Parse"} + } + dst.ParameterOIDs = append(dst.ParameterOIDs, binary.BigEndian.Uint32(buf.Next(4))) + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Parse) Encode(dst []byte) []byte { + dst = append(dst, 'P') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, src.Name...) + dst = append(dst, 0) + dst = append(dst, src.Query...) + dst = append(dst, 0) + + dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) + for _, oid := range src.ParameterOIDs { + dst = pgio.AppendUint32(dst, oid) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src Parse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Name string + Query string + ParameterOIDs []uint32 + }{ + Type: "Parse", + Name: src.Name, + Query: src.Query, + ParameterOIDs: src.ParameterOIDs, + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/parse_complete.go b/vendor/github.com/jackc/pgproto3/v2/parse_complete.go new file mode 100644 index 000000000..92c9498b6 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/parse_complete.go @@ -0,0 +1,34 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type ParseComplete struct{} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*ParseComplete) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *ParseComplete) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *ParseComplete) Encode(dst []byte) []byte { + return append(dst, '1', 0, 0, 0, 4) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src ParseComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "ParseComplete", + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/password_message.go b/vendor/github.com/jackc/pgproto3/v2/password_message.go new file mode 100644 index 000000000..cae76c50c --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/password_message.go @@ -0,0 +1,54 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" + + "github.com/jackc/pgio" +) + +type PasswordMessage struct { + Password string +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*PasswordMessage) Frontend() {} + +// Frontend identifies this message as an authentication response. +func (*PasswordMessage) InitialResponse() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *PasswordMessage) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.Password = string(b[:len(b)-1]) + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *PasswordMessage) Encode(dst []byte) []byte { + dst = append(dst, 'p') + dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1)) + + dst = append(dst, src.Password...) + dst = append(dst, 0) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src PasswordMessage) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Password string + }{ + Type: "PasswordMessage", + Password: src.Password, + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/pgproto3.go b/vendor/github.com/jackc/pgproto3/v2/pgproto3.go new file mode 100644 index 000000000..70c825e3c --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/pgproto3.go @@ -0,0 +1,65 @@ +package pgproto3 + +import ( + "encoding/hex" + "errors" + "fmt" +) + +// Message is the interface implemented by an object that can decode and encode +// a particular PostgreSQL message. +type Message interface { + // Decode is allowed and expected to retain a reference to data after + // returning (unlike encoding.BinaryUnmarshaler). + Decode(data []byte) error + + // Encode appends itself to dst and returns the new buffer. + Encode(dst []byte) []byte +} + +type FrontendMessage interface { + Message + Frontend() // no-op method to distinguish frontend from backend methods +} + +type BackendMessage interface { + Message + Backend() // no-op method to distinguish frontend from backend methods +} + +type AuthenticationResponseMessage interface { + BackendMessage + AuthenticationResponse() // no-op method to distinguish authentication responses +} + +type invalidMessageLenErr struct { + messageType string + expectedLen int + actualLen int +} + +func (e *invalidMessageLenErr) Error() string { + return fmt.Sprintf("%s body must have length of %d, but it is %d", e.messageType, e.expectedLen, e.actualLen) +} + +type invalidMessageFormatErr struct { + messageType string +} + +func (e *invalidMessageFormatErr) Error() string { + return fmt.Sprintf("%s body is invalid", e.messageType) +} + +// getValueFromJSON gets the value from a protocol message representation in JSON. +func getValueFromJSON(v map[string]string) ([]byte, error) { + if v == nil { + return nil, nil + } + if text, ok := v["text"]; ok { + return []byte(text), nil + } + if binary, ok := v["binary"]; ok { + return hex.DecodeString(binary) + } + return nil, errors.New("unknown protocol representation") +} diff --git a/vendor/github.com/jackc/pgproto3/v2/portal_suspended.go b/vendor/github.com/jackc/pgproto3/v2/portal_suspended.go new file mode 100644 index 000000000..1a9e7bfb1 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/portal_suspended.go @@ -0,0 +1,34 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type PortalSuspended struct{} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*PortalSuspended) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *PortalSuspended) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "PortalSuspended", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *PortalSuspended) Encode(dst []byte) []byte { + return append(dst, 's', 0, 0, 0, 4) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src PortalSuspended) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "PortalSuspended", + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/query.go b/vendor/github.com/jackc/pgproto3/v2/query.go new file mode 100644 index 000000000..41c93b4a8 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/query.go @@ -0,0 +1,50 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" + + "github.com/jackc/pgio" +) + +type Query struct { + String string +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*Query) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *Query) Decode(src []byte) error { + i := bytes.IndexByte(src, 0) + if i != len(src)-1 { + return &invalidMessageFormatErr{messageType: "Query"} + } + + dst.String = string(src[:i]) + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Query) Encode(dst []byte) []byte { + dst = append(dst, 'Q') + dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1)) + + dst = append(dst, src.String...) + dst = append(dst, 0) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src Query) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + String string + }{ + Type: "Query", + String: src.String, + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/ready_for_query.go b/vendor/github.com/jackc/pgproto3/v2/ready_for_query.go new file mode 100644 index 000000000..67a39be39 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/ready_for_query.go @@ -0,0 +1,61 @@ +package pgproto3 + +import ( + "encoding/json" + "errors" +) + +type ReadyForQuery struct { + TxStatus byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*ReadyForQuery) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *ReadyForQuery) Decode(src []byte) error { + if len(src) != 1 { + return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)} + } + + dst.TxStatus = src[0] + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *ReadyForQuery) Encode(dst []byte) []byte { + return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src ReadyForQuery) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + TxStatus string + }{ + Type: "ReadyForQuery", + TxStatus: string(src.TxStatus), + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *ReadyForQuery) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + TxStatus string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + if len(msg.TxStatus) != 1 { + return errors.New("invalid length for ReadyForQuery.TxStatus") + } + dst.TxStatus = msg.TxStatus[0] + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/row_description.go b/vendor/github.com/jackc/pgproto3/v2/row_description.go new file mode 100644 index 000000000..a2e0d28e2 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/row_description.go @@ -0,0 +1,165 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgio" +) + +const ( + TextFormat = 0 + BinaryFormat = 1 +) + +type FieldDescription struct { + Name []byte + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier int32 + Format int16 +} + +// MarshalJSON implements encoding/json.Marshaler. +func (fd FieldDescription) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Name string + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier int32 + Format int16 + }{ + Name: string(fd.Name), + TableOID: fd.TableOID, + TableAttributeNumber: fd.TableAttributeNumber, + DataTypeOID: fd.DataTypeOID, + DataTypeSize: fd.DataTypeSize, + TypeModifier: fd.TypeModifier, + Format: fd.Format, + }) +} + +type RowDescription struct { + Fields []FieldDescription +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*RowDescription) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *RowDescription) Decode(src []byte) error { + + if len(src) < 2 { + return &invalidMessageFormatErr{messageType: "RowDescription"} + } + fieldCount := int(binary.BigEndian.Uint16(src)) + rp := 2 + + dst.Fields = dst.Fields[0:0] + + for i := 0; i < fieldCount; i++ { + var fd FieldDescription + + idx := bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "RowDescription"} + } + fd.Name = src[rp : rp+idx] + rp += idx + 1 + + // Since buf.Next() doesn't return an error if we hit the end of the buffer + // check Len ahead of time + if len(src[rp:]) < 18 { + return &invalidMessageFormatErr{messageType: "RowDescription"} + } + + fd.TableOID = binary.BigEndian.Uint32(src[rp:]) + rp += 4 + fd.TableAttributeNumber = binary.BigEndian.Uint16(src[rp:]) + rp += 2 + fd.DataTypeOID = binary.BigEndian.Uint32(src[rp:]) + rp += 4 + fd.DataTypeSize = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + fd.TypeModifier = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + fd.Format = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + dst.Fields = append(dst.Fields, fd) + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *RowDescription) Encode(dst []byte) []byte { + dst = append(dst, 'T') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = pgio.AppendUint16(dst, uint16(len(src.Fields))) + for _, fd := range src.Fields { + dst = append(dst, fd.Name...) + dst = append(dst, 0) + + dst = pgio.AppendUint32(dst, fd.TableOID) + dst = pgio.AppendUint16(dst, fd.TableAttributeNumber) + dst = pgio.AppendUint32(dst, fd.DataTypeOID) + dst = pgio.AppendInt16(dst, fd.DataTypeSize) + dst = pgio.AppendInt32(dst, fd.TypeModifier) + dst = pgio.AppendInt16(dst, fd.Format) + } + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src RowDescription) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Fields []FieldDescription + }{ + Type: "RowDescription", + Fields: src.Fields, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *RowDescription) UnmarshalJSON(data []byte) error { + var msg struct { + Fields []struct { + Name string + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier int32 + Format int16 + } + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + dst.Fields = make([]FieldDescription, len(msg.Fields)) + for n, field := range msg.Fields { + dst.Fields[n] = FieldDescription{ + Name: []byte(field.Name), + TableOID: field.TableOID, + TableAttributeNumber: field.TableAttributeNumber, + DataTypeOID: field.DataTypeOID, + DataTypeSize: field.DataTypeSize, + TypeModifier: field.TypeModifier, + Format: field.Format, + } + } + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/sasl_initial_response.go b/vendor/github.com/jackc/pgproto3/v2/sasl_initial_response.go new file mode 100644 index 000000000..f7e5f36a9 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/sasl_initial_response.go @@ -0,0 +1,94 @@ +package pgproto3 + +import ( + "bytes" + "encoding/hex" + "encoding/json" + "errors" + + "github.com/jackc/pgio" +) + +type SASLInitialResponse struct { + AuthMechanism string + Data []byte +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*SASLInitialResponse) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *SASLInitialResponse) Decode(src []byte) error { + *dst = SASLInitialResponse{} + + rp := 0 + + idx := bytes.IndexByte(src, 0) + if idx < 0 { + return errors.New("invalid SASLInitialResponse") + } + + dst.AuthMechanism = string(src[rp:idx]) + rp = idx + 1 + + rp += 4 // The rest of the message is data so we can just skip the size + dst.Data = src[rp:] + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *SASLInitialResponse) Encode(dst []byte) []byte { + dst = append(dst, 'p') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, []byte(src.AuthMechanism)...) + dst = append(dst, 0) + + dst = pgio.AppendInt32(dst, int32(len(src.Data))) + dst = append(dst, src.Data...) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src SASLInitialResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + AuthMechanism string + Data string + }{ + Type: "SASLInitialResponse", + AuthMechanism: src.AuthMechanism, + Data: hex.EncodeToString(src.Data), + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *SASLInitialResponse) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + AuthMechanism string + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + dst.AuthMechanism = msg.AuthMechanism + if msg.Data != "" { + decoded, err := hex.DecodeString(msg.Data) + if err != nil { + return err + } + dst.Data = decoded + } + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/sasl_response.go b/vendor/github.com/jackc/pgproto3/v2/sasl_response.go new file mode 100644 index 000000000..41fb4c397 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/sasl_response.go @@ -0,0 +1,61 @@ +package pgproto3 + +import ( + "encoding/hex" + "encoding/json" + + "github.com/jackc/pgio" +) + +type SASLResponse struct { + Data []byte +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*SASLResponse) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *SASLResponse) Decode(src []byte) error { + *dst = SASLResponse{Data: src} + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *SASLResponse) Encode(dst []byte) []byte { + dst = append(dst, 'p') + dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) + + dst = append(dst, src.Data...) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src SASLResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "SASLResponse", + Data: hex.EncodeToString(src.Data), + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (dst *SASLResponse) UnmarshalJSON(data []byte) error { + var msg struct { + Data string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + if msg.Data != "" { + decoded, err := hex.DecodeString(msg.Data) + if err != nil { + return err + } + dst.Data = decoded + } + return nil +} diff --git a/vendor/github.com/jackc/pgproto3/v2/ssl_request.go b/vendor/github.com/jackc/pgproto3/v2/ssl_request.go new file mode 100644 index 000000000..96ce489e5 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/ssl_request.go @@ -0,0 +1,49 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgio" +) + +const sslRequestNumber = 80877103 + +type SSLRequest struct { +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*SSLRequest) Frontend() {} + +func (dst *SSLRequest) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("ssl request too short") + } + + requestCode := binary.BigEndian.Uint32(src) + + if requestCode != sslRequestNumber { + return errors.New("bad ssl request code") + } + + return nil +} + +// Encode encodes src into dst. dst will include the 4 byte message length. +func (src *SSLRequest) Encode(dst []byte) []byte { + dst = pgio.AppendInt32(dst, 8) + dst = pgio.AppendInt32(dst, sslRequestNumber) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src SSLRequest) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProtocolVersion uint32 + Parameters map[string]string + }{ + Type: "SSLRequest", + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/startup_message.go b/vendor/github.com/jackc/pgproto3/v2/startup_message.go new file mode 100644 index 000000000..5f1cd24f7 --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/startup_message.go @@ -0,0 +1,96 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + + "github.com/jackc/pgio" +) + +const ProtocolVersionNumber = 196608 // 3.0 + +type StartupMessage struct { + ProtocolVersion uint32 + Parameters map[string]string +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*StartupMessage) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *StartupMessage) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("startup message too short") + } + + dst.ProtocolVersion = binary.BigEndian.Uint32(src) + rp := 4 + + if dst.ProtocolVersion != ProtocolVersionNumber { + return fmt.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) + } + + dst.Parameters = make(map[string]string) + for { + idx := bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "StartupMesage"} + } + key := string(src[rp : rp+idx]) + rp += idx + 1 + + idx = bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "StartupMesage"} + } + value := string(src[rp : rp+idx]) + rp += idx + 1 + + dst.Parameters[key] = value + + if len(src[rp:]) == 1 { + if src[rp] != 0 { + return fmt.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp]) + } + break + } + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *StartupMessage) Encode(dst []byte) []byte { + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = pgio.AppendUint32(dst, src.ProtocolVersion) + for k, v := range src.Parameters { + dst = append(dst, k...) + dst = append(dst, 0) + dst = append(dst, v...) + dst = append(dst, 0) + } + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src StartupMessage) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProtocolVersion uint32 + Parameters map[string]string + }{ + Type: "StartupMessage", + ProtocolVersion: src.ProtocolVersion, + Parameters: src.Parameters, + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/sync.go b/vendor/github.com/jackc/pgproto3/v2/sync.go new file mode 100644 index 000000000..5db8e07ac --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/sync.go @@ -0,0 +1,34 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type Sync struct{} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*Sync) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *Sync) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "Sync", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Sync) Encode(dst []byte) []byte { + return append(dst, 'S', 0, 0, 0, 4) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src Sync) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "Sync", + }) +} diff --git a/vendor/github.com/jackc/pgproto3/v2/terminate.go b/vendor/github.com/jackc/pgproto3/v2/terminate.go new file mode 100644 index 000000000..135191eae --- /dev/null +++ b/vendor/github.com/jackc/pgproto3/v2/terminate.go @@ -0,0 +1,34 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type Terminate struct{} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*Terminate) Frontend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *Terminate) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "Terminate", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *Terminate) Encode(dst []byte) []byte { + return append(dst, 'X', 0, 0, 0, 4) +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src Terminate) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "Terminate", + }) +} diff --git a/vendor/github.com/jackc/pgservicefile/.travis.yml b/vendor/github.com/jackc/pgservicefile/.travis.yml new file mode 100644 index 000000000..e176228e8 --- /dev/null +++ b/vendor/github.com/jackc/pgservicefile/.travis.yml @@ -0,0 +1,9 @@ +language: go + +go: + - 1.x + - tip + +matrix: + allow_failures: + - go: tip diff --git a/vendor/github.com/jackc/pgservicefile/LICENSE b/vendor/github.com/jackc/pgservicefile/LICENSE new file mode 100644 index 000000000..f1b4c2892 --- /dev/null +++ b/vendor/github.com/jackc/pgservicefile/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2020 Jack Christensen + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/jackc/pgservicefile/README.md b/vendor/github.com/jackc/pgservicefile/README.md new file mode 100644 index 000000000..e50ca1262 --- /dev/null +++ b/vendor/github.com/jackc/pgservicefile/README.md @@ -0,0 +1,6 @@ +[](https://godoc.org/github.com/jackc/pgservicefile) +[](https://travis-ci.org/jackc/pgservicefile) + +# pgservicefile + +Package pgservicefile is a parser for PostgreSQL service files (e.g. `.pg_service.conf`). diff --git a/vendor/github.com/jackc/pgservicefile/go.mod b/vendor/github.com/jackc/pgservicefile/go.mod new file mode 100644 index 000000000..051e9e0f4 --- /dev/null +++ b/vendor/github.com/jackc/pgservicefile/go.mod @@ -0,0 +1,5 @@ +module github.com/jackc/pgservicefile + +go 1.14 + +require github.com/stretchr/testify v1.5.1 diff --git a/vendor/github.com/vmihailenco/bufpool/go.sum b/vendor/github.com/jackc/pgservicefile/go.sum index 6074473ac..a80206ab1 100644 --- a/vendor/github.com/vmihailenco/bufpool/go.sum +++ b/vendor/github.com/jackc/pgservicefile/go.sum @@ -1,17 +1,10 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/vendor/github.com/jackc/pgservicefile/pgservicefile.go b/vendor/github.com/jackc/pgservicefile/pgservicefile.go new file mode 100644 index 000000000..797bbab9e --- /dev/null +++ b/vendor/github.com/jackc/pgservicefile/pgservicefile.go @@ -0,0 +1,79 @@ +// Package pgservicefile is a parser for PostgreSQL service files (e.g. .pg_service.conf). +package pgservicefile + +import ( + "bufio" + "errors" + "fmt" + "io" + "os" + "strings" +) + +type Service struct { + Name string + Settings map[string]string +} + +type Servicefile struct { + Services []*Service + servicesByName map[string]*Service +} + +// GetService returns the named service. +func (sf *Servicefile) GetService(name string) (*Service, error) { + service, present := sf.servicesByName[name] + if !present { + return nil, errors.New("not found") + } + return service, nil +} + +// ReadServicefile reads the file at path and parses it into a Servicefile. +func ReadServicefile(path string) (*Servicefile, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + return ParseServicefile(f) +} + +// ParseServicefile reads r and parses it into a Servicefile. +func ParseServicefile(r io.Reader) (*Servicefile, error) { + servicefile := &Servicefile{} + + var service *Service + scanner := bufio.NewScanner(r) + lineNum := 0 + for scanner.Scan() { + lineNum += 1 + line := scanner.Text() + line = strings.TrimSpace(line) + + if line == "" || strings.HasPrefix(line, "#") { + // ignore comments and empty lines + } else if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { + service = &Service{Name: line[1 : len(line)-1], Settings: make(map[string]string)} + servicefile.Services = append(servicefile.Services, service) + } else { + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("unable to parse line %d", lineNum) + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + service.Settings[key] = value + } + } + + servicefile.servicesByName = make(map[string]*Service, len(servicefile.Services)) + for _, service := range servicefile.Services { + servicefile.servicesByName[service.Name] = service + } + + return servicefile, scanner.Err() +} diff --git a/vendor/github.com/jackc/pgtype/.travis.yml b/vendor/github.com/jackc/pgtype/.travis.yml new file mode 100644 index 000000000..d67627350 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/.travis.yml @@ -0,0 +1,34 @@ +# source: https://github.com/jackc/pgx/blob/master/.travis.yml + +language: go + +go: + - 1.14.x + - 1.13.x + - tip + +# Derived from https://github.com/lib/pq/blob/master/.travis.yml +before_install: + - ./travis/before_install.bash + +env: + global: + - GO111MODULE=on + - PGX_TEST_DATABASE=postgres://pgx_md5:secret@127.0.0.1/pgx_test + + matrix: + - PGVERSION=12 + - PGVERSION=11 + - PGVERSION=10 + - PGVERSION=9.6 + - PGVERSION=9.5 + +before_script: + - ./travis/before_script.bash + +script: + - ./travis/script.bash + +matrix: + allow_failures: + - go: tip
\ No newline at end of file diff --git a/vendor/github.com/jackc/pgtype/CHANGELOG.md b/vendor/github.com/jackc/pgtype/CHANGELOG.md new file mode 100644 index 000000000..64d96fa00 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/CHANGELOG.md @@ -0,0 +1,103 @@ +# 1.8.1 (July 24, 2021) + +* Cleaned up Go module dependency chain + +# 1.8.0 (July 10, 2021) + +* Maintain host bits for inet types (Cameron Daniel) +* Support pointers of wrapping structs (Ivan Daunis) +* Register JSONBArray at NewConnInfo() (Rueian) +* CompositeTextScanner handles backslash escapes + +# 1.7.0 (March 25, 2021) + +* Fix scanning int into **sql.Scanner implementor +* Add tsrange array type (Vasilii Novikov) +* Fix: escaped strings when they start or end with a newline char (Stephane Martin) +* Accept nil *time.Time in Time.Set +* Fix numeric NaN support +* Use Go 1.13 errors instead of xerrors + +# 1.6.2 (December 3, 2020) + +* Fix panic on assigning empty array to non-slice or array +* Fix text array parsing disambiguates NULL and "NULL" +* Fix Timestamptz.DecodeText with too short text + +# 1.6.1 (October 31, 2020) + +* Fix simple protocol empty array support + +# 1.6.0 (October 24, 2020) + +* Fix AssignTo pointer to pointer to slice and named types. +* Fix zero length array assignment (Simo Haasanen) +* Add float64, float32 convert to int2, int4, int8 (lqu3j) +* Support setting infinite timestamps (Erik Agsjö) +* Polygon improvements (duohedron) +* Fix Inet.Set with nil (Tomas Volf) + +# 1.5.0 (September 26, 2020) + +* Add slice of slice mapping to multi-dimensional arrays (Simo Haasanen) +* Fix JSONBArray +* Fix selecting empty array +* Text formatted values except bytea can be directly scanned to []byte +* Add JSON marshalling for UUID (bakmataliev) +* Improve point type conversions (bakmataliev) + +# 1.4.2 (July 22, 2020) + +* Fix encoding of a large composite data type (Yaz Saito) + +# 1.4.1 (July 14, 2020) + +* Fix ArrayType DecodeBinary empty array breaks future reads + +# 1.4.0 (June 27, 2020) + +* Add JSON support to ext/gofrs-uuid +* Performance improvements in Scan path +* Improved ext/shopspring-numeric binary decoding performance +* Add composite type support (Maxim Ivanov and Jack Christensen) +* Add better generic enum type support +* Add generic array type support +* Clarify and normalize Value semantics +* Fix hstore with empty string values +* Numeric supports NaN values (leighhopcroft) +* Add slice of pointer support to array types (megaturbo) +* Add jsonb array type (tserakhau) +* Allow converting intervals with months and days to duration + +# 1.3.0 (March 30, 2020) + +* Get implemented on T instead of *T +* Set will call Get on src if possible +* Range types Set method supports its own type, string, and nil +* Date.Set parses string +* Fix correct format verb for unknown type error (Robert Welin) +* Truncate nanoseconds in EncodeText for Timestamptz and Timestamp + +# 1.2.0 (February 5, 2020) + +* Add zeronull package for easier NULL <-> zero conversion +* Add JSON marshalling for shopspring-numeric extension +* Add JSON marshalling for Bool, Date, JSON/B, Timestamptz (Jeffrey Stiles) +* Fix null status in UnmarshalJSON for some types (Jeffrey Stiles) + +# 1.1.0 (January 11, 2020) + +* Add PostgreSQL time type support +* Add more automatic conversions of integer arrays of different types (Jean-Philippe Quéméner) + +# 1.0.3 (November 16, 2019) + +* Support initializing Array types from a slice of the value (Alex Gaynor) + +# 1.0.2 (October 22, 2019) + +* Fix scan into null into pointer to pointer implementing Decode* interface. (Jeremy Altavilla) + +# 1.0.1 (September 19, 2019) + +* Fix daterange OID diff --git a/vendor/github.com/jackc/pgtype/LICENSE b/vendor/github.com/jackc/pgtype/LICENSE new file mode 100644 index 000000000..5c486c39a --- /dev/null +++ b/vendor/github.com/jackc/pgtype/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2013-2021 Jack Christensen + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/jackc/pgtype/README.md b/vendor/github.com/jackc/pgtype/README.md new file mode 100644 index 000000000..77d59b313 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/README.md @@ -0,0 +1,8 @@ +[](https://godoc.org/github.com/jackc/pgtype) + + +# pgtype + +pgtype implements Go types for over 70 PostgreSQL types. pgtype is the type system underlying the +https://github.com/jackc/pgx PostgreSQL driver. These types support the binary format for enhanced performance with pgx. +They also support the database/sql `Scan` and `Value` interfaces and can be used with https://github.com/lib/pq. diff --git a/vendor/github.com/jackc/pgtype/aclitem.go b/vendor/github.com/jackc/pgtype/aclitem.go new file mode 100644 index 000000000..9f6587be7 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/aclitem.go @@ -0,0 +1,138 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" +) + +// ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem +// might look like this: +// +// postgres=arwdDxt/postgres +// +// Note, however, that because the user/role name part of an aclitem is +// an identifier, it follows all the usual formatting rules for SQL +// identifiers: if it contains spaces and other special characters, +// it should appear in double-quotes: +// +// postgres=arwdDxt/"role with spaces" +// +type ACLItem struct { + String string + Status Status +} + +func (dst *ACLItem) Set(src interface{}) error { + if src == nil { + *dst = ACLItem{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case string: + *dst = ACLItem{String: value, Status: Present} + case *string: + if value == nil { + *dst = ACLItem{Status: Null} + } else { + *dst = ACLItem{String: *value, Status: Present} + } + default: + if originalSrc, ok := underlyingStringType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to ACLItem", value) + } + + return nil +} + +func (dst ACLItem) Get() interface{} { + switch dst.Status { + case Present: + return dst.String + case Null: + return nil + default: + return dst.Status + } +} + +func (src *ACLItem) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *string: + *v = src.String + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (dst *ACLItem) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = ACLItem{Status: Null} + return nil + } + + *dst = ACLItem{String: string(src), Status: Present} + return nil +} + +func (src ACLItem) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, src.String...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *ACLItem) Scan(src interface{}) error { + if src == nil { + *dst = ACLItem{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src ACLItem) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.String, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/vendor/github.com/jackc/pgtype/aclitem_array.go b/vendor/github.com/jackc/pgtype/aclitem_array.go new file mode 100644 index 000000000..4e3be3bda --- /dev/null +++ b/vendor/github.com/jackc/pgtype/aclitem_array.go @@ -0,0 +1,428 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "fmt" + "reflect" +) + +type ACLItemArray struct { + Elements []ACLItem + Dimensions []ArrayDimension + Status Status +} + +func (dst *ACLItemArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = ACLItemArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []string: + if value == nil { + *dst = ACLItemArray{Status: Null} + } else if len(value) == 0 { + *dst = ACLItemArray{Status: Present} + } else { + elements := make([]ACLItem, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = ACLItemArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*string: + if value == nil { + *dst = ACLItemArray{Status: Null} + } else if len(value) == 0 { + *dst = ACLItemArray{Status: Present} + } else { + elements := make([]ACLItem, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = ACLItemArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []ACLItem: + if value == nil { + *dst = ACLItemArray{Status: Null} + } else if len(value) == 0 { + *dst = ACLItemArray{Status: Present} + } else { + *dst = ACLItemArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = ACLItemArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for ACLItemArray", src) + } + if elementsLength == 0 { + *dst = ACLItemArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to ACLItemArray", src) + } + + *dst = ACLItemArray{ + Elements: make([]ACLItem, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]ACLItem, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to ACLItemArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *ACLItemArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to ACLItemArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in ACLItemArray", err) + } + index++ + + return index, nil +} + +func (dst ACLItemArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *ACLItemArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *ACLItemArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from ACLItemArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from ACLItemArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *ACLItemArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = ACLItemArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []ACLItem + + if len(uta.Elements) > 0 { + elements = make([]ACLItem, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem ACLItem + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = ACLItemArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (src ACLItemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *ACLItemArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src ACLItemArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/array.go b/vendor/github.com/jackc/pgtype/array.go new file mode 100644 index 000000000..3d5930c1c --- /dev/null +++ b/vendor/github.com/jackc/pgtype/array.go @@ -0,0 +1,381 @@ +package pgtype + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "reflect" + "strconv" + "strings" + "unicode" + + "github.com/jackc/pgio" +) + +// Information on the internals of PostgreSQL arrays can be found in +// src/include/utils/array.h and src/backend/utils/adt/arrayfuncs.c. Of +// particular interest is the array_send function. + +type ArrayHeader struct { + ContainsNull bool + ElementOID int32 + Dimensions []ArrayDimension +} + +type ArrayDimension struct { + Length int32 + LowerBound int32 +} + +func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { + if len(src) < 12 { + return 0, fmt.Errorf("array header too short: %d", len(src)) + } + + rp := 0 + + numDims := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1 + rp += 4 + + dst.ElementOID = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + if numDims > 0 { + dst.Dimensions = make([]ArrayDimension, numDims) + } + if len(src) < 12+numDims*8 { + return 0, fmt.Errorf("array header too short for %d dimensions: %d", numDims, len(src)) + } + for i := range dst.Dimensions { + dst.Dimensions[i].Length = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + dst.Dimensions[i].LowerBound = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + } + + return rp, nil +} + +func (src ArrayHeader) EncodeBinary(ci *ConnInfo, buf []byte) []byte { + buf = pgio.AppendInt32(buf, int32(len(src.Dimensions))) + + var containsNull int32 + if src.ContainsNull { + containsNull = 1 + } + buf = pgio.AppendInt32(buf, containsNull) + + buf = pgio.AppendInt32(buf, src.ElementOID) + + for i := range src.Dimensions { + buf = pgio.AppendInt32(buf, src.Dimensions[i].Length) + buf = pgio.AppendInt32(buf, src.Dimensions[i].LowerBound) + } + + return buf +} + +type UntypedTextArray struct { + Elements []string + Quoted []bool + Dimensions []ArrayDimension +} + +func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { + dst := &UntypedTextArray{} + + buf := bytes.NewBufferString(src) + + skipWhitespace(buf) + + r, _, err := buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + var explicitDimensions []ArrayDimension + + // Array has explicit dimensions + if r == '[' { + buf.UnreadRune() + + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r == '=' { + break + } else if r != '[' { + return nil, fmt.Errorf("invalid array, expected '[' or '=' got %v", r) + } + + lower, err := arrayParseInteger(buf) + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r != ':' { + return nil, fmt.Errorf("invalid array, expected ':' got %v", r) + } + + upper, err := arrayParseInteger(buf) + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r != ']' { + return nil, fmt.Errorf("invalid array, expected ']' got %v", r) + } + + explicitDimensions = append(explicitDimensions, ArrayDimension{LowerBound: lower, Length: upper - lower + 1}) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + } + + if r != '{' { + return nil, fmt.Errorf("invalid array, expected '{': %v", err) + } + + implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}} + + // Consume all initial opening brackets. This provides number of dimensions. + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r == '{' { + implicitDimensions[len(implicitDimensions)-1].Length = 1 + implicitDimensions = append(implicitDimensions, ArrayDimension{LowerBound: 1}) + } else { + buf.UnreadRune() + break + } + } + currentDim := len(implicitDimensions) - 1 + counterDim := currentDim + + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + switch r { + case '{': + if currentDim == counterDim { + implicitDimensions[currentDim].Length++ + } + currentDim++ + case ',': + case '}': + currentDim-- + if currentDim < counterDim { + counterDim = currentDim + } + default: + buf.UnreadRune() + value, quoted, err := arrayParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid array value: %v", err) + } + if currentDim == counterDim { + implicitDimensions[currentDim].Length++ + } + dst.Quoted = append(dst.Quoted, quoted) + dst.Elements = append(dst.Elements, value) + } + + if currentDim < 0 { + break + } + } + + skipWhitespace(buf) + + if buf.Len() > 0 { + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + } + + if len(dst.Elements) == 0 { + dst.Dimensions = nil + } else if len(explicitDimensions) > 0 { + dst.Dimensions = explicitDimensions + } else { + dst.Dimensions = implicitDimensions + } + + return dst, nil +} + +func skipWhitespace(buf *bytes.Buffer) { + var r rune + var err error + for r, _, _ = buf.ReadRune(); unicode.IsSpace(r); r, _, _ = buf.ReadRune() { + } + + if err != io.EOF { + buf.UnreadRune() + } +} + +func arrayParseValue(buf *bytes.Buffer) (string, bool, error) { + r, _, err := buf.ReadRune() + if err != nil { + return "", false, err + } + if r == '"' { + return arrayParseQuotedValue(buf) + } + buf.UnreadRune() + + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", false, err + } + + switch r { + case ',', '}': + buf.UnreadRune() + return s.String(), false, nil + } + + s.WriteRune(r) + } +} + +func arrayParseQuotedValue(buf *bytes.Buffer) (string, bool, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", false, err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", false, err + } + case '"': + r, _, err = buf.ReadRune() + if err != nil { + return "", false, err + } + buf.UnreadRune() + return s.String(), true, nil + } + s.WriteRune(r) + } +} + +func arrayParseInteger(buf *bytes.Buffer) (int32, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return 0, err + } + + if '0' <= r && r <= '9' { + s.WriteRune(r) + } else { + buf.UnreadRune() + n, err := strconv.ParseInt(s.String(), 10, 32) + if err != nil { + return 0, err + } + return int32(n), nil + } + } +} + +func EncodeTextArrayDimensions(buf []byte, dimensions []ArrayDimension) []byte { + var customDimensions bool + for _, dim := range dimensions { + if dim.LowerBound != 1 { + customDimensions = true + } + } + + if !customDimensions { + return buf + } + + for _, dim := range dimensions { + buf = append(buf, '[') + buf = append(buf, strconv.FormatInt(int64(dim.LowerBound), 10)...) + buf = append(buf, ':') + buf = append(buf, strconv.FormatInt(int64(dim.LowerBound+dim.Length-1), 10)...) + buf = append(buf, ']') + } + + return append(buf, '=') +} + +var quoteArrayReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func quoteArrayElement(src string) string { + return `"` + quoteArrayReplacer.Replace(src) + `"` +} + +func isSpace(ch byte) bool { + // see https://github.com/postgres/postgres/blob/REL_12_STABLE/src/backend/parser/scansup.c#L224 + return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\f' +} + +func QuoteArrayElementIfNeeded(src string) string { + if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || isSpace(src[0]) || isSpace(src[len(src)-1]) || strings.ContainsAny(src, `{},"\`) { + return quoteArrayElement(src) + } + return src +} + +func findDimensionsFromValue(value reflect.Value, dimensions []ArrayDimension, elementsLength int) ([]ArrayDimension, int, bool) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + length := value.Len() + if 0 == elementsLength { + elementsLength = length + } else { + elementsLength *= length + } + dimensions = append(dimensions, ArrayDimension{Length: int32(length), LowerBound: 1}) + for i := 0; i < length; i++ { + if d, l, ok := findDimensionsFromValue(value.Index(i), dimensions, elementsLength); ok { + return d, l, true + } + } + } + return dimensions, elementsLength, true +} diff --git a/vendor/github.com/jackc/pgtype/array_type.go b/vendor/github.com/jackc/pgtype/array_type.go new file mode 100644 index 000000000..1bd0244b7 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/array_type.go @@ -0,0 +1,353 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +// ArrayType represents an array type. While it implements Value, this is only in service of its type conversion duties +// when registered as a data type in a ConnType. It should not be used directly as a Value. ArrayType is a convenience +// type for types that do not have an concrete array type. +type ArrayType struct { + elements []ValueTranscoder + dimensions []ArrayDimension + + typeName string + newElement func() ValueTranscoder + + elementOID uint32 + status Status +} + +func NewArrayType(typeName string, elementOID uint32, newElement func() ValueTranscoder) *ArrayType { + return &ArrayType{typeName: typeName, elementOID: elementOID, newElement: newElement} +} + +func (at *ArrayType) NewTypeValue() Value { + return &ArrayType{ + elements: at.elements, + dimensions: at.dimensions, + status: at.status, + + typeName: at.typeName, + elementOID: at.elementOID, + newElement: at.newElement, + } +} + +func (at *ArrayType) TypeName() string { + return at.typeName +} + +func (dst *ArrayType) setNil() { + dst.elements = nil + dst.dimensions = nil + dst.status = Null +} + +func (dst *ArrayType) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + dst.setNil() + return nil + } + + sliceVal := reflect.ValueOf(src) + if sliceVal.Kind() != reflect.Slice { + return fmt.Errorf("cannot set non-slice") + } + + if sliceVal.IsNil() { + dst.setNil() + return nil + } + + dst.elements = make([]ValueTranscoder, sliceVal.Len()) + for i := range dst.elements { + v := dst.newElement() + err := v.Set(sliceVal.Index(i).Interface()) + if err != nil { + return err + } + + dst.elements[i] = v + } + dst.dimensions = []ArrayDimension{{Length: int32(len(dst.elements)), LowerBound: 1}} + dst.status = Present + + return nil +} + +func (dst ArrayType) Get() interface{} { + switch dst.status { + case Present: + elementValues := make([]interface{}, len(dst.elements)) + for i := range dst.elements { + elementValues[i] = dst.elements[i].Get() + } + return elementValues + case Null: + return nil + default: + return dst.status + } +} + +func (src *ArrayType) AssignTo(dst interface{}) error { + ptrSlice := reflect.ValueOf(dst) + if ptrSlice.Kind() != reflect.Ptr { + return fmt.Errorf("cannot assign to non-pointer") + } + + sliceVal := ptrSlice.Elem() + sliceType := sliceVal.Type() + + if sliceType.Kind() != reflect.Slice { + return fmt.Errorf("cannot assign to pointer to non-slice") + } + + switch src.status { + case Present: + slice := reflect.MakeSlice(sliceType, len(src.elements), len(src.elements)) + elemType := sliceType.Elem() + + for i := range src.elements { + ptrElem := reflect.New(elemType) + err := src.elements[i].AssignTo(ptrElem.Interface()) + if err != nil { + return err + } + + slice.Index(i).Set(ptrElem.Elem()) + } + + sliceVal.Set(slice) + return nil + case Null: + sliceVal.Set(reflect.Zero(sliceType)) + return nil + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + dst.setNil() + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []ValueTranscoder + + if len(uta.Elements) > 0 { + elements = make([]ValueTranscoder, len(uta.Elements)) + + for i, s := range uta.Elements { + elem := dst.newElement() + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + dst.elements = elements + dst.dimensions = uta.Dimensions + dst.status = Present + + return nil +} + +func (dst *ArrayType) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + dst.setNil() + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + var elements []ValueTranscoder + + if len(arrayHeader.Dimensions) == 0 { + dst.elements = elements + dst.dimensions = arrayHeader.Dimensions + dst.status = Present + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements = make([]ValueTranscoder, elementCount) + + for i := range elements { + elem := dst.newElement() + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elem.DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + + dst.elements = elements + dst.dimensions = arrayHeader.Dimensions + dst.status = Present + + return nil +} + +func (src ArrayType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.dimensions)) + dimElemCounts[len(src.dimensions)-1] = int(src.dimensions[len(src.dimensions)-1].Length) + for i := len(src.dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src ArrayType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.dimensions, + ElementOID: int32(src.elementOID), + } + + for i := range src.elements { + if src.elements[i].Get() == nil { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *ArrayType) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src ArrayType) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/bit.go b/vendor/github.com/jackc/pgtype/bit.go new file mode 100644 index 000000000..c1709e6b9 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/bit.go @@ -0,0 +1,45 @@ +package pgtype + +import ( + "database/sql/driver" +) + +type Bit Varbit + +func (dst *Bit) Set(src interface{}) error { + return (*Varbit)(dst).Set(src) +} + +func (dst Bit) Get() interface{} { + return (Varbit)(dst).Get() +} + +func (src *Bit) AssignTo(dst interface{}) error { + return (*Varbit)(src).AssignTo(dst) +} + +func (dst *Bit) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Varbit)(dst).DecodeBinary(ci, src) +} + +func (src Bit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Varbit)(src).EncodeBinary(ci, buf) +} + +func (dst *Bit) DecodeText(ci *ConnInfo, src []byte) error { + return (*Varbit)(dst).DecodeText(ci, src) +} + +func (src Bit) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Varbit)(src).EncodeText(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Bit) Scan(src interface{}) error { + return (*Varbit)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Bit) Value() (driver.Value, error) { + return (Varbit)(src).Value() +} diff --git a/vendor/github.com/jackc/pgtype/bool.go b/vendor/github.com/jackc/pgtype/bool.go new file mode 100644 index 000000000..676c8e5d3 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/bool.go @@ -0,0 +1,217 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "strconv" +) + +type Bool struct { + Bool bool + Status Status +} + +func (dst *Bool) Set(src interface{}) error { + if src == nil { + *dst = Bool{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case bool: + *dst = Bool{Bool: value, Status: Present} + case string: + bb, err := strconv.ParseBool(value) + if err != nil { + return err + } + *dst = Bool{Bool: bb, Status: Present} + case *bool: + if value == nil { + *dst = Bool{Status: Null} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Bool{Status: Null} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingBoolType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Bool", value) + } + + return nil +} + +func (dst Bool) Get() interface{} { + switch dst.Status { + case Present: + return dst.Bool + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Bool) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *bool: + *v = src.Bool + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Bool{Status: Null} + return nil + } + + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) + } + + *dst = Bool{Bool: src[0] == 't', Status: Present} + return nil +} + +func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Bool{Status: Null} + return nil + } + + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) + } + + *dst = Bool{Bool: src[0] == 1, Status: Present} + return nil +} + +func (src Bool) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if src.Bool { + buf = append(buf, 't') + } else { + buf = append(buf, 'f') + } + + return buf, nil +} + +func (src Bool) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if src.Bool { + buf = append(buf, 1) + } else { + buf = append(buf, 0) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Bool) Scan(src interface{}) error { + if src == nil { + *dst = Bool{Status: Null} + return nil + } + + switch src := src.(type) { + case bool: + *dst = Bool{Bool: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Bool) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Bool, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} + +func (src Bool) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + if src.Bool { + return []byte("true"), nil + } else { + return []byte("false"), nil + } + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + + return nil, errBadStatus +} + +func (dst *Bool) UnmarshalJSON(b []byte) error { + var v *bool + err := json.Unmarshal(b, &v) + if err != nil { + return err + } + + if v == nil { + *dst = Bool{Status: Null} + } else { + *dst = Bool{Bool: *v, Status: Present} + } + + return nil +} diff --git a/vendor/github.com/jackc/pgtype/bool_array.go b/vendor/github.com/jackc/pgtype/bool_array.go new file mode 100644 index 000000000..6558d971c --- /dev/null +++ b/vendor/github.com/jackc/pgtype/bool_array.go @@ -0,0 +1,517 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type BoolArray struct { + Elements []Bool + Dimensions []ArrayDimension + Status Status +} + +func (dst *BoolArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = BoolArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []bool: + if value == nil { + *dst = BoolArray{Status: Null} + } else if len(value) == 0 { + *dst = BoolArray{Status: Present} + } else { + elements := make([]Bool, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = BoolArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*bool: + if value == nil { + *dst = BoolArray{Status: Null} + } else if len(value) == 0 { + *dst = BoolArray{Status: Present} + } else { + elements := make([]Bool, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = BoolArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Bool: + if value == nil { + *dst = BoolArray{Status: Null} + } else if len(value) == 0 { + *dst = BoolArray{Status: Present} + } else { + *dst = BoolArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = BoolArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for BoolArray", src) + } + if elementsLength == 0 { + *dst = BoolArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to BoolArray", src) + } + + *dst = BoolArray{ + Elements: make([]Bool, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Bool, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to BoolArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *BoolArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to BoolArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in BoolArray", err) + } + index++ + + return index, nil +} + +func (dst BoolArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *BoolArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]bool: + *v = make([]bool, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*bool: + *v = make([]*bool, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *BoolArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from BoolArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from BoolArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = BoolArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Bool + + if len(uta.Elements) > 0 { + elements = make([]Bool, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Bool + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = BoolArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = BoolArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = BoolArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Bool, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = BoolArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("bool"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "bool") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *BoolArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src BoolArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/box.go b/vendor/github.com/jackc/pgtype/box.go new file mode 100644 index 000000000..27fb829ee --- /dev/null +++ b/vendor/github.com/jackc/pgtype/box.go @@ -0,0 +1,165 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgio" +) + +type Box struct { + P [2]Vec2 + Status Status +} + +func (dst *Box) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Box", src) +} + +func (dst Box) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Box) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Box{Status: Null} + return nil + } + + if len(src) < 11 { + return fmt.Errorf("invalid length for Box: %v", len(src)) + } + + str := string(src[1:]) + + var end int + end = strings.IndexByte(str, ',') + + x1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+3:] + end = strings.IndexByte(str, ',') + + x2, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1 : len(str)-1] + + y2, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } + + *dst = Box{P: [2]Vec2{{x1, y1}, {x2, y2}}, Status: Present} + return nil +} + +func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Box{Status: Null} + return nil + } + + if len(src) != 32 { + return fmt.Errorf("invalid length for Box: %v", len(src)) + } + + x1 := binary.BigEndian.Uint64(src) + y1 := binary.BigEndian.Uint64(src[8:]) + x2 := binary.BigEndian.Uint64(src[16:]) + y2 := binary.BigEndian.Uint64(src[24:]) + + *dst = Box{ + P: [2]Vec2{ + {math.Float64frombits(x1), math.Float64frombits(y1)}, + {math.Float64frombits(x2), math.Float64frombits(y2)}, + }, + Status: Present, + } + return nil +} + +func (src Box) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, + strconv.FormatFloat(src.P[0].X, 'f', -1, 64), + strconv.FormatFloat(src.P[0].Y, 'f', -1, 64), + strconv.FormatFloat(src.P[1].X, 'f', -1, 64), + strconv.FormatFloat(src.P[1].Y, 'f', -1, 64), + )...) + return buf, nil +} + +func (src Box) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].Y)) + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Box) Scan(src interface{}) error { + if src == nil { + *dst = Box{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Box) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/vendor/github.com/jackc/pgtype/bpchar.go b/vendor/github.com/jackc/pgtype/bpchar.go new file mode 100644 index 000000000..e4d058e92 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/bpchar.go @@ -0,0 +1,76 @@ +package pgtype + +import ( + "database/sql/driver" +) + +// BPChar is fixed-length, blank padded char type +// character(n), char(n) +type BPChar Text + +// Set converts from src to dst. +func (dst *BPChar) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +// Get returns underlying value +func (dst BPChar) Get() interface{} { + return (Text)(dst).Get() +} + +// AssignTo assigns from src to dst. +func (src *BPChar) AssignTo(dst interface{}) error { + if src.Status == Present { + switch v := dst.(type) { + case *rune: + runes := []rune(src.String) + if len(runes) == 1 { + *v = runes[0] + return nil + } + } + } + return (*Text)(src).AssignTo(dst) +} + +func (BPChar) PreferredResultFormat() int16 { + return TextFormatCode +} + +func (dst *BPChar) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *BPChar) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} + +func (BPChar) PreferredParamFormat() int16 { + return TextFormatCode +} + +func (src BPChar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Text)(src).EncodeText(ci, buf) +} + +func (src BPChar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Text)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *BPChar) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src BPChar) Value() (driver.Value, error) { + return (Text)(src).Value() +} + +func (src BPChar) MarshalJSON() ([]byte, error) { + return (Text)(src).MarshalJSON() +} + +func (dst *BPChar) UnmarshalJSON(b []byte) error { + return (*Text)(dst).UnmarshalJSON(b) +} diff --git a/vendor/github.com/jackc/pgtype/bpchar_array.go b/vendor/github.com/jackc/pgtype/bpchar_array.go new file mode 100644 index 000000000..8e7922142 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/bpchar_array.go @@ -0,0 +1,517 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type BPCharArray struct { + Elements []BPChar + Dimensions []ArrayDimension + Status Status +} + +func (dst *BPCharArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = BPCharArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []string: + if value == nil { + *dst = BPCharArray{Status: Null} + } else if len(value) == 0 { + *dst = BPCharArray{Status: Present} + } else { + elements := make([]BPChar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = BPCharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*string: + if value == nil { + *dst = BPCharArray{Status: Null} + } else if len(value) == 0 { + *dst = BPCharArray{Status: Present} + } else { + elements := make([]BPChar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = BPCharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []BPChar: + if value == nil { + *dst = BPCharArray{Status: Null} + } else if len(value) == 0 { + *dst = BPCharArray{Status: Present} + } else { + *dst = BPCharArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = BPCharArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for BPCharArray", src) + } + if elementsLength == 0 { + *dst = BPCharArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to BPCharArray", src) + } + + *dst = BPCharArray{ + Elements: make([]BPChar, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]BPChar, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to BPCharArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *BPCharArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to BPCharArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in BPCharArray", err) + } + index++ + + return index, nil +} + +func (dst BPCharArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *BPCharArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *BPCharArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from BPCharArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from BPCharArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *BPCharArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = BPCharArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []BPChar + + if len(uta.Elements) > 0 { + elements = make([]BPChar, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem BPChar + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = BPCharArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *BPCharArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = BPCharArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = BPCharArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]BPChar, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = BPCharArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src BPCharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src BPCharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("bpchar"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "bpchar") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *BPCharArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src BPCharArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/bytea.go b/vendor/github.com/jackc/pgtype/bytea.go new file mode 100644 index 000000000..67eba3502 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/bytea.go @@ -0,0 +1,163 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/hex" + "fmt" +) + +type Bytea struct { + Bytes []byte + Status Status +} + +func (dst *Bytea) Set(src interface{}) error { + if src == nil { + *dst = Bytea{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case []byte: + if value != nil { + *dst = Bytea{Bytes: value, Status: Present} + } else { + *dst = Bytea{Status: Null} + } + default: + if originalSrc, ok := underlyingBytesType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Bytea", value) + } + + return nil +} + +func (dst Bytea) Get() interface{} { + switch dst.Status { + case Present: + return dst.Bytes + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Bytea) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *[]byte: + buf := make([]byte, len(src.Bytes)) + copy(buf, src.Bytes) + *v = buf + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +// DecodeText only supports the hex format. This has been the default since +// PostgreSQL 9.0. +func (dst *Bytea) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Bytea{Status: Null} + return nil + } + + if len(src) < 2 || src[0] != '\\' || src[1] != 'x' { + return fmt.Errorf("invalid hex format") + } + + buf := make([]byte, (len(src)-2)/2) + _, err := hex.Decode(buf, src[2:]) + if err != nil { + return err + } + + *dst = Bytea{Bytes: buf, Status: Present} + return nil +} + +func (dst *Bytea) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Bytea{Status: Null} + return nil + } + + *dst = Bytea{Bytes: src, Status: Present} + return nil +} + +func (src Bytea) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, `\x`...) + buf = append(buf, hex.EncodeToString(src.Bytes)...) + return buf, nil +} + +func (src Bytea) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, src.Bytes...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Bytea) Scan(src interface{}) error { + if src == nil { + *dst = Bytea{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + buf := make([]byte, len(src)) + copy(buf, src) + *dst = Bytea{Bytes: buf, Status: Present} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Bytea) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Bytes, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/vendor/github.com/jackc/pgtype/bytea_array.go b/vendor/github.com/jackc/pgtype/bytea_array.go new file mode 100644 index 000000000..69d1ceb98 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/bytea_array.go @@ -0,0 +1,489 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type ByteaArray struct { + Elements []Bytea + Dimensions []ArrayDimension + Status Status +} + +func (dst *ByteaArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = ByteaArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case [][]byte: + if value == nil { + *dst = ByteaArray{Status: Null} + } else if len(value) == 0 { + *dst = ByteaArray{Status: Present} + } else { + elements := make([]Bytea, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = ByteaArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Bytea: + if value == nil { + *dst = ByteaArray{Status: Null} + } else if len(value) == 0 { + *dst = ByteaArray{Status: Present} + } else { + *dst = ByteaArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = ByteaArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for ByteaArray", src) + } + if elementsLength == 0 { + *dst = ByteaArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to ByteaArray", src) + } + + *dst = ByteaArray{ + Elements: make([]Bytea, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Bytea, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to ByteaArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *ByteaArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to ByteaArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in ByteaArray", err) + } + index++ + + return index, nil +} + +func (dst ByteaArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *ByteaArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[][]byte: + *v = make([][]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *ByteaArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from ByteaArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from ByteaArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = ByteaArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Bytea + + if len(uta.Elements) > 0 { + elements = make([]Bytea, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Bytea + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = ByteaArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *ByteaArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = ByteaArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = ByteaArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Bytea, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = ByteaArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src ByteaArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("bytea"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "bytea") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *ByteaArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src ByteaArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/cid.go b/vendor/github.com/jackc/pgtype/cid.go new file mode 100644 index 000000000..b944748c7 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/cid.go @@ -0,0 +1,61 @@ +package pgtype + +import ( + "database/sql/driver" +) + +// CID is PostgreSQL's Command Identifier type. +// +// When one does +// +// select cmin, cmax, * from some_table; +// +// it is the data type of the cmin and cmax hidden system columns. +// +// It is currently implemented as an unsigned four byte integer. +// Its definition can be found in src/include/c.h as CommandId +// in the PostgreSQL sources. +type CID pguint32 + +// Set converts from src to dst. Note that as CID is not a general +// number type Set does not do automatic type conversion as other number +// types do. +func (dst *CID) Set(src interface{}) error { + return (*pguint32)(dst).Set(src) +} + +func (dst CID) Get() interface{} { + return (pguint32)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as CID is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *CID) AssignTo(dst interface{}) error { + return (*pguint32)(src).AssignTo(dst) +} + +func (dst *CID) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) +} + +func (dst *CID) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) +} + +func (src CID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (pguint32)(src).EncodeText(ci, buf) +} + +func (src CID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (pguint32)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *CID) Scan(src interface{}) error { + return (*pguint32)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src CID) Value() (driver.Value, error) { + return (pguint32)(src).Value() +} diff --git a/vendor/github.com/jackc/pgtype/cidr.go b/vendor/github.com/jackc/pgtype/cidr.go new file mode 100644 index 000000000..2241ca1c0 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/cidr.go @@ -0,0 +1,31 @@ +package pgtype + +type CIDR Inet + +func (dst *CIDR) Set(src interface{}) error { + return (*Inet)(dst).Set(src) +} + +func (dst CIDR) Get() interface{} { + return (Inet)(dst).Get() +} + +func (src *CIDR) AssignTo(dst interface{}) error { + return (*Inet)(src).AssignTo(dst) +} + +func (dst *CIDR) DecodeText(ci *ConnInfo, src []byte) error { + return (*Inet)(dst).DecodeText(ci, src) +} + +func (dst *CIDR) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Inet)(dst).DecodeBinary(ci, src) +} + +func (src CIDR) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Inet)(src).EncodeText(ci, buf) +} + +func (src CIDR) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Inet)(src).EncodeBinary(ci, buf) +} diff --git a/vendor/github.com/jackc/pgtype/cidr_array.go b/vendor/github.com/jackc/pgtype/cidr_array.go new file mode 100644 index 000000000..783c599c4 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/cidr_array.go @@ -0,0 +1,546 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "net" + "reflect" + + "github.com/jackc/pgio" +) + +type CIDRArray struct { + Elements []CIDR + Dimensions []ArrayDimension + Status Status +} + +func (dst *CIDRArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = CIDRArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []*net.IPNet: + if value == nil { + *dst = CIDRArray{Status: Null} + } else if len(value) == 0 { + *dst = CIDRArray{Status: Present} + } else { + elements := make([]CIDR, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = CIDRArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []net.IP: + if value == nil { + *dst = CIDRArray{Status: Null} + } else if len(value) == 0 { + *dst = CIDRArray{Status: Present} + } else { + elements := make([]CIDR, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = CIDRArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*net.IP: + if value == nil { + *dst = CIDRArray{Status: Null} + } else if len(value) == 0 { + *dst = CIDRArray{Status: Present} + } else { + elements := make([]CIDR, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = CIDRArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []CIDR: + if value == nil { + *dst = CIDRArray{Status: Null} + } else if len(value) == 0 { + *dst = CIDRArray{Status: Present} + } else { + *dst = CIDRArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = CIDRArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for CIDRArray", src) + } + if elementsLength == 0 { + *dst = CIDRArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to CIDRArray", src) + } + + *dst = CIDRArray{ + Elements: make([]CIDR, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]CIDR, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to CIDRArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *CIDRArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to CIDRArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in CIDRArray", err) + } + index++ + + return index, nil +} + +func (dst CIDRArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *CIDRArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]*net.IPNet: + *v = make([]*net.IPNet, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]net.IP: + *v = make([]net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*net.IP: + *v = make([]*net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *CIDRArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from CIDRArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from CIDRArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *CIDRArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = CIDRArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []CIDR + + if len(uta.Elements) > 0 { + elements = make([]CIDR, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem CIDR + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = CIDRArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *CIDRArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = CIDRArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = CIDRArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]CIDR, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = CIDRArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("cidr"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "cidr") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *CIDRArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src CIDRArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/circle.go b/vendor/github.com/jackc/pgtype/circle.go new file mode 100644 index 000000000..4279650e3 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/circle.go @@ -0,0 +1,150 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgio" +) + +type Circle struct { + P Vec2 + R float64 + Status Status +} + +func (dst *Circle) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Circle", src) +} + +func (dst Circle) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Circle) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Circle) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Circle{Status: Null} + return nil + } + + if len(src) < 9 { + return fmt.Errorf("invalid length for Circle: %v", len(src)) + } + + str := string(src[2:]) + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+2 : len(str)-1] + + r, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } + + *dst = Circle{P: Vec2{x, y}, R: r, Status: Present} + return nil +} + +func (dst *Circle) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Circle{Status: Null} + return nil + } + + if len(src) != 24 { + return fmt.Errorf("invalid length for Circle: %v", len(src)) + } + + x := binary.BigEndian.Uint64(src) + y := binary.BigEndian.Uint64(src[8:]) + r := binary.BigEndian.Uint64(src[16:]) + + *dst = Circle{ + P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, + R: math.Float64frombits(r), + Status: Present, + } + return nil +} + +func (src Circle) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, fmt.Sprintf(`<(%s,%s),%s>`, + strconv.FormatFloat(src.P.X, 'f', -1, 64), + strconv.FormatFloat(src.P.Y, 'f', -1, 64), + strconv.FormatFloat(src.R, 'f', -1, 64), + )...) + + return buf, nil +} + +func (src Circle) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendUint64(buf, math.Float64bits(src.P.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P.Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.R)) + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Circle) Scan(src interface{}) error { + if src == nil { + *dst = Circle{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Circle) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/vendor/github.com/jackc/pgtype/composite_fields.go b/vendor/github.com/jackc/pgtype/composite_fields.go new file mode 100644 index 000000000..b6d09fcf2 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/composite_fields.go @@ -0,0 +1,107 @@ +package pgtype + +import "fmt" + +// CompositeFields scans the fields of a composite type into the elements of the CompositeFields value. To scan a +// nullable value use a *CompositeFields. It will be set to nil in case of null. +// +// CompositeFields implements EncodeBinary and EncodeText. However, functionality is limited due to CompositeFields not +// knowing the PostgreSQL schema of the composite type. Prefer using a registered CompositeType. +type CompositeFields []interface{} + +func (cf CompositeFields) DecodeBinary(ci *ConnInfo, src []byte) error { + if len(cf) == 0 { + return fmt.Errorf("cannot decode into empty CompositeFields") + } + + if src == nil { + return fmt.Errorf("cannot decode unexpected null into CompositeFields") + } + + scanner := NewCompositeBinaryScanner(ci, src) + + for _, f := range cf { + scanner.ScanValue(f) + } + + if scanner.Err() != nil { + return scanner.Err() + } + + return nil +} + +func (cf CompositeFields) DecodeText(ci *ConnInfo, src []byte) error { + if len(cf) == 0 { + return fmt.Errorf("cannot decode into empty CompositeFields") + } + + if src == nil { + return fmt.Errorf("cannot decode unexpected null into CompositeFields") + } + + scanner := NewCompositeTextScanner(ci, src) + + for _, f := range cf { + scanner.ScanValue(f) + } + + if scanner.Err() != nil { + return scanner.Err() + } + + return nil +} + +// EncodeText encodes composite fields into the text format. Prefer registering a CompositeType to using +// CompositeFields to encode directly. +func (cf CompositeFields) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + b := NewCompositeTextBuilder(ci, buf) + + for _, f := range cf { + if textEncoder, ok := f.(TextEncoder); ok { + b.AppendEncoder(textEncoder) + } else { + b.AppendValue(f) + } + } + + return b.Finish() +} + +// EncodeBinary encodes composite fields into the binary format. Unlike CompositeType the schema of the destination is +// unknown. Prefer registering a CompositeType to using CompositeFields to encode directly. Because the binary +// composite format requires the OID of each field to be specified the only types that will work are those known to +// ConnInfo. +// +// In particular: +// +// * Nil cannot be used because there is no way to determine what type it. +// * Integer types must be exact matches. e.g. A Go int32 into a PostgreSQL bigint will fail. +// * No dereferencing will be done. e.g. *Text must be used instead of Text. +func (cf CompositeFields) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + b := NewCompositeBinaryBuilder(ci, buf) + + for _, f := range cf { + dt, ok := ci.DataTypeForValue(f) + if !ok { + return nil, fmt.Errorf("Unknown OID for %#v", f) + } + + if binaryEncoder, ok := f.(BinaryEncoder); ok { + b.AppendEncoder(dt.OID, binaryEncoder) + } else { + err := dt.Value.Set(f) + if err != nil { + return nil, err + } + if binaryEncoder, ok := dt.Value.(BinaryEncoder); ok { + b.AppendEncoder(dt.OID, binaryEncoder) + } else { + return nil, fmt.Errorf("Cannot encode binary format for %v", f) + } + } + } + + return b.Finish() +} diff --git a/vendor/github.com/jackc/pgtype/composite_type.go b/vendor/github.com/jackc/pgtype/composite_type.go new file mode 100644 index 000000000..32e0aa26b --- /dev/null +++ b/vendor/github.com/jackc/pgtype/composite_type.go @@ -0,0 +1,682 @@ +package pgtype + +import ( + "encoding/binary" + "errors" + "fmt" + "reflect" + "strings" + + "github.com/jackc/pgio" +) + +type CompositeTypeField struct { + Name string + OID uint32 +} + +type CompositeType struct { + status Status + + typeName string + + fields []CompositeTypeField + valueTranscoders []ValueTranscoder +} + +// NewCompositeType creates a CompositeType from fields and ci. ci is used to find the ValueTranscoders used +// for fields. All field OIDs must be previously registered in ci. +func NewCompositeType(typeName string, fields []CompositeTypeField, ci *ConnInfo) (*CompositeType, error) { + valueTranscoders := make([]ValueTranscoder, len(fields)) + + for i := range fields { + dt, ok := ci.DataTypeForOID(fields[i].OID) + if !ok { + return nil, fmt.Errorf("no data type registered for oid: %d", fields[i].OID) + } + + value := NewValue(dt.Value) + valueTranscoder, ok := value.(ValueTranscoder) + if !ok { + return nil, fmt.Errorf("data type for oid does not implement ValueTranscoder: %d", fields[i].OID) + } + + valueTranscoders[i] = valueTranscoder + } + + return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: valueTranscoders}, nil +} + +// NewCompositeTypeValues creates a CompositeType from fields and values. fields and values must have the same length. +// Prefer NewCompositeType unless overriding the transcoding of fields is required. +func NewCompositeTypeValues(typeName string, fields []CompositeTypeField, values []ValueTranscoder) (*CompositeType, error) { + if len(fields) != len(values) { + return nil, errors.New("fields and valueTranscoders must have same length") + } + + return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: values}, nil +} + +func (src CompositeType) Get() interface{} { + switch src.status { + case Present: + results := make(map[string]interface{}, len(src.valueTranscoders)) + for i := range src.valueTranscoders { + results[src.fields[i].Name] = src.valueTranscoders[i].Get() + } + return results + case Null: + return nil + default: + return src.status + } +} + +func (ct *CompositeType) NewTypeValue() Value { + a := &CompositeType{ + typeName: ct.typeName, + fields: ct.fields, + valueTranscoders: make([]ValueTranscoder, len(ct.valueTranscoders)), + } + + for i := range ct.valueTranscoders { + a.valueTranscoders[i] = NewValue(ct.valueTranscoders[i]).(ValueTranscoder) + } + + return a +} + +func (ct *CompositeType) TypeName() string { + return ct.typeName +} + +func (ct *CompositeType) Fields() []CompositeTypeField { + return ct.fields +} + +func (dst *CompositeType) Set(src interface{}) error { + if src == nil { + dst.status = Null + return nil + } + + switch value := src.(type) { + case []interface{}: + if len(value) != len(dst.valueTranscoders) { + return fmt.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.valueTranscoders)) + } + for i, v := range value { + if err := dst.valueTranscoders[i].Set(v); err != nil { + return err + } + } + dst.status = Present + case *[]interface{}: + if value == nil { + dst.status = Null + return nil + } + return dst.Set(*value) + default: + return fmt.Errorf("Can not convert %v to Composite", src) + } + + return nil +} + +// AssignTo should never be called on composite value directly +func (src CompositeType) AssignTo(dst interface{}) error { + switch src.status { + case Present: + switch v := dst.(type) { + case []interface{}: + if len(v) != len(src.valueTranscoders) { + return fmt.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.valueTranscoders)) + } + for i := range src.valueTranscoders { + if v[i] == nil { + continue + } + + err := assignToOrSet(src.valueTranscoders[i], v[i]) + if err != nil { + return fmt.Errorf("unable to assign to dst[%d]: %v", i, err) + } + } + return nil + case *[]interface{}: + return src.AssignTo(*v) + default: + if isPtrStruct, err := src.assignToPtrStruct(dst); isPtrStruct { + return err + } + + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func assignToOrSet(src Value, dst interface{}) error { + assignToErr := src.AssignTo(dst) + if assignToErr != nil { + // Try to use get / set instead -- this avoids every type having to be able to AssignTo type of self. + setSucceeded := false + if setter, ok := dst.(Value); ok { + err := setter.Set(src.Get()) + setSucceeded = err == nil + } + if !setSucceeded { + return assignToErr + } + } + + return nil +} + +func (src CompositeType) assignToPtrStruct(dst interface{}) (bool, error) { + dstValue := reflect.ValueOf(dst) + if dstValue.Kind() != reflect.Ptr { + return false, nil + } + + if dstValue.IsNil() { + return false, nil + } + + dstElemValue := dstValue.Elem() + dstElemType := dstElemValue.Type() + + if dstElemType.Kind() != reflect.Struct { + return false, nil + } + + exportedFields := make([]int, 0, dstElemType.NumField()) + for i := 0; i < dstElemType.NumField(); i++ { + sf := dstElemType.Field(i) + if sf.PkgPath == "" { + exportedFields = append(exportedFields, i) + } + } + + if len(exportedFields) != len(src.valueTranscoders) { + return false, nil + } + + for i := range exportedFields { + err := assignToOrSet(src.valueTranscoders[i], dstElemValue.Field(exportedFields[i]).Addr().Interface()) + if err != nil { + return true, fmt.Errorf("unable to assign to field %s: %v", dstElemType.Field(exportedFields[i]).Name, err) + } + } + + return true, nil +} + +func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { + switch src.status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + b := NewCompositeBinaryBuilder(ci, buf) + for i := range src.valueTranscoders { + b.AppendEncoder(src.fields[i].OID, src.valueTranscoders[i]) + } + + return b.Finish() +} + +// DecodeBinary implements BinaryDecoder interface. +// Opposite to Record, fields in a composite act as a "schema" +// and decoding fails if SQL value can't be assigned due to +// type mismatch +func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error { + if buf == nil { + dst.status = Null + return nil + } + + scanner := NewCompositeBinaryScanner(ci, buf) + + for _, f := range dst.valueTranscoders { + scanner.ScanDecoder(f) + } + + if scanner.Err() != nil { + return scanner.Err() + } + + dst.status = Present + + return nil +} + +func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error { + if buf == nil { + dst.status = Null + return nil + } + + scanner := NewCompositeTextScanner(ci, buf) + + for _, f := range dst.valueTranscoders { + scanner.ScanDecoder(f) + } + + if scanner.Err() != nil { + return scanner.Err() + } + + dst.status = Present + + return nil +} + +func (src CompositeType) EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) { + switch src.status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + b := NewCompositeTextBuilder(ci, buf) + for _, f := range src.valueTranscoders { + b.AppendEncoder(f) + } + + return b.Finish() +} + +type CompositeBinaryScanner struct { + ci *ConnInfo + rp int + src []byte + + fieldCount int32 + fieldBytes []byte + fieldOID uint32 + err error +} + +// NewCompositeBinaryScanner a scanner over a binary encoded composite balue. +func NewCompositeBinaryScanner(ci *ConnInfo, src []byte) *CompositeBinaryScanner { + rp := 0 + if len(src[rp:]) < 4 { + return &CompositeBinaryScanner{err: fmt.Errorf("Record incomplete %v", src)} + } + + fieldCount := int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + return &CompositeBinaryScanner{ + ci: ci, + rp: rp, + src: src, + fieldCount: fieldCount, + } +} + +// ScanDecoder calls Next and decodes the result with d. +func (cfs *CompositeBinaryScanner) ScanDecoder(d BinaryDecoder) { + if cfs.err != nil { + return + } + + if cfs.Next() { + cfs.err = d.DecodeBinary(cfs.ci, cfs.fieldBytes) + } else { + cfs.err = errors.New("read past end of composite") + } +} + +// ScanDecoder calls Next and scans the result into d. +func (cfs *CompositeBinaryScanner) ScanValue(d interface{}) { + if cfs.err != nil { + return + } + + if cfs.Next() { + cfs.err = cfs.ci.Scan(cfs.OID(), BinaryFormatCode, cfs.Bytes(), d) + } else { + cfs.err = errors.New("read past end of composite") + } +} + +// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After +// Next returns false, the Err method can be called to check if any errors occurred. +func (cfs *CompositeBinaryScanner) Next() bool { + if cfs.err != nil { + return false + } + + if cfs.rp == len(cfs.src) { + return false + } + + if len(cfs.src[cfs.rp:]) < 8 { + cfs.err = fmt.Errorf("Record incomplete %v", cfs.src) + return false + } + cfs.fieldOID = binary.BigEndian.Uint32(cfs.src[cfs.rp:]) + cfs.rp += 4 + + fieldLen := int(int32(binary.BigEndian.Uint32(cfs.src[cfs.rp:]))) + cfs.rp += 4 + + if fieldLen >= 0 { + if len(cfs.src[cfs.rp:]) < fieldLen { + cfs.err = fmt.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src) + return false + } + cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+fieldLen] + cfs.rp += fieldLen + } else { + cfs.fieldBytes = nil + } + + return true +} + +func (cfs *CompositeBinaryScanner) FieldCount() int { + return int(cfs.fieldCount) +} + +// Bytes returns the bytes of the field most recently read by Scan(). +func (cfs *CompositeBinaryScanner) Bytes() []byte { + return cfs.fieldBytes +} + +// OID returns the OID of the field most recently read by Scan(). +func (cfs *CompositeBinaryScanner) OID() uint32 { + return cfs.fieldOID +} + +// Err returns any error encountered by the scanner. +func (cfs *CompositeBinaryScanner) Err() error { + return cfs.err +} + +type CompositeTextScanner struct { + ci *ConnInfo + rp int + src []byte + + fieldBytes []byte + err error +} + +// NewCompositeTextScanner a scanner over a text encoded composite value. +func NewCompositeTextScanner(ci *ConnInfo, src []byte) *CompositeTextScanner { + if len(src) < 2 { + return &CompositeTextScanner{err: fmt.Errorf("Record incomplete %v", src)} + } + + if src[0] != '(' { + return &CompositeTextScanner{err: fmt.Errorf("composite text format must start with '('")} + } + + if src[len(src)-1] != ')' { + return &CompositeTextScanner{err: fmt.Errorf("composite text format must end with ')'")} + } + + return &CompositeTextScanner{ + ci: ci, + rp: 1, + src: src, + } +} + +// ScanDecoder calls Next and decodes the result with d. +func (cfs *CompositeTextScanner) ScanDecoder(d TextDecoder) { + if cfs.err != nil { + return + } + + if cfs.Next() { + cfs.err = d.DecodeText(cfs.ci, cfs.fieldBytes) + } else { + cfs.err = errors.New("read past end of composite") + } +} + +// ScanDecoder calls Next and scans the result into d. +func (cfs *CompositeTextScanner) ScanValue(d interface{}) { + if cfs.err != nil { + return + } + + if cfs.Next() { + cfs.err = cfs.ci.Scan(0, TextFormatCode, cfs.Bytes(), d) + } else { + cfs.err = errors.New("read past end of composite") + } +} + +// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After +// Next returns false, the Err method can be called to check if any errors occurred. +func (cfs *CompositeTextScanner) Next() bool { + if cfs.err != nil { + return false + } + + if cfs.rp == len(cfs.src) { + return false + } + + switch cfs.src[cfs.rp] { + case ',', ')': // null + cfs.rp++ + cfs.fieldBytes = nil + return true + case '"': // quoted value + cfs.rp++ + cfs.fieldBytes = make([]byte, 0, 16) + for { + ch := cfs.src[cfs.rp] + + if ch == '"' { + cfs.rp++ + if cfs.src[cfs.rp] == '"' { + cfs.fieldBytes = append(cfs.fieldBytes, '"') + cfs.rp++ + } else { + break + } + } else if ch == '\\' { + cfs.rp++ + cfs.fieldBytes = append(cfs.fieldBytes, cfs.src[cfs.rp]) + cfs.rp++ + } else { + cfs.fieldBytes = append(cfs.fieldBytes, ch) + cfs.rp++ + } + } + cfs.rp++ + return true + default: // unquoted value + start := cfs.rp + for { + ch := cfs.src[cfs.rp] + if ch == ',' || ch == ')' { + break + } + cfs.rp++ + } + cfs.fieldBytes = cfs.src[start:cfs.rp] + cfs.rp++ + return true + } +} + +// Bytes returns the bytes of the field most recently read by Scan(). +func (cfs *CompositeTextScanner) Bytes() []byte { + return cfs.fieldBytes +} + +// Err returns any error encountered by the scanner. +func (cfs *CompositeTextScanner) Err() error { + return cfs.err +} + +type CompositeBinaryBuilder struct { + ci *ConnInfo + buf []byte + startIdx int + fieldCount uint32 + err error +} + +func NewCompositeBinaryBuilder(ci *ConnInfo, buf []byte) *CompositeBinaryBuilder { + startIdx := len(buf) + buf = append(buf, 0, 0, 0, 0) // allocate room for number of fields + return &CompositeBinaryBuilder{ci: ci, buf: buf, startIdx: startIdx} +} + +func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field interface{}) { + if b.err != nil { + return + } + + dt, ok := b.ci.DataTypeForOID(oid) + if !ok { + b.err = fmt.Errorf("unknown data type for OID: %d", oid) + return + } + + err := dt.Value.Set(field) + if err != nil { + b.err = err + return + } + + binaryEncoder, ok := dt.Value.(BinaryEncoder) + if !ok { + b.err = fmt.Errorf("unable to encode binary for OID: %d", oid) + return + } + + b.AppendEncoder(oid, binaryEncoder) +} + +func (b *CompositeBinaryBuilder) AppendEncoder(oid uint32, field BinaryEncoder) { + if b.err != nil { + return + } + + b.buf = pgio.AppendUint32(b.buf, oid) + lengthPos := len(b.buf) + b.buf = pgio.AppendInt32(b.buf, -1) + fieldBuf, err := field.EncodeBinary(b.ci, b.buf) + if err != nil { + b.err = err + return + } + if fieldBuf != nil { + binary.BigEndian.PutUint32(fieldBuf[lengthPos:], uint32(len(fieldBuf)-len(b.buf))) + b.buf = fieldBuf + } + + b.fieldCount++ +} + +func (b *CompositeBinaryBuilder) Finish() ([]byte, error) { + if b.err != nil { + return nil, b.err + } + + binary.BigEndian.PutUint32(b.buf[b.startIdx:], b.fieldCount) + return b.buf, nil +} + +type CompositeTextBuilder struct { + ci *ConnInfo + buf []byte + startIdx int + fieldCount uint32 + err error + fieldBuf [32]byte +} + +func NewCompositeTextBuilder(ci *ConnInfo, buf []byte) *CompositeTextBuilder { + buf = append(buf, '(') // allocate room for number of fields + return &CompositeTextBuilder{ci: ci, buf: buf} +} + +func (b *CompositeTextBuilder) AppendValue(field interface{}) { + if b.err != nil { + return + } + + if field == nil { + b.buf = append(b.buf, ',') + return + } + + dt, ok := b.ci.DataTypeForValue(field) + if !ok { + b.err = fmt.Errorf("unknown data type for field: %v", field) + return + } + + err := dt.Value.Set(field) + if err != nil { + b.err = err + return + } + + textEncoder, ok := dt.Value.(TextEncoder) + if !ok { + b.err = fmt.Errorf("unable to encode text for value: %v", field) + return + } + + b.AppendEncoder(textEncoder) +} + +func (b *CompositeTextBuilder) AppendEncoder(field TextEncoder) { + if b.err != nil { + return + } + + fieldBuf, err := field.EncodeText(b.ci, b.fieldBuf[0:0]) + if err != nil { + b.err = err + return + } + if fieldBuf != nil { + b.buf = append(b.buf, quoteCompositeFieldIfNeeded(string(fieldBuf))...) + } + + b.buf = append(b.buf, ',') +} + +func (b *CompositeTextBuilder) Finish() ([]byte, error) { + if b.err != nil { + return nil, b.err + } + + b.buf[len(b.buf)-1] = ')' + return b.buf, nil +} + +var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func quoteCompositeField(src string) string { + return `"` + quoteCompositeReplacer.Replace(src) + `"` +} + +func quoteCompositeFieldIfNeeded(src string) string { + if src == "" || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `(),"\`) { + return quoteCompositeField(src) + } + return src +} diff --git a/vendor/github.com/jackc/pgtype/convert.go b/vendor/github.com/jackc/pgtype/convert.go new file mode 100644 index 000000000..de9ba9ba3 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/convert.go @@ -0,0 +1,472 @@ +package pgtype + +import ( + "database/sql" + "fmt" + "math" + "reflect" + "time" +) + +const ( + maxUint = ^uint(0) + maxInt = int(maxUint >> 1) + minInt = -maxInt - 1 +) + +// underlyingNumberType gets the underlying type that can be converted to Int2, Int4, Int8, Float4, or Float8 +func underlyingNumberType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Int: + convVal := int(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int8: + convVal := int8(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int16: + convVal := int16(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int32: + convVal := int32(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int64: + convVal := int64(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint: + convVal := uint(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint8: + convVal := uint8(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint16: + convVal := uint16(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint32: + convVal := uint32(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint64: + convVal := uint64(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Float32: + convVal := float32(refVal.Float()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Float64: + convVal := refVal.Float() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.String: + convVal := refVal.String() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingBoolType gets the underlying type that can be converted to Bool +func underlyingBoolType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Bool: + convVal := refVal.Bool() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingBytesType gets the underlying type that can be converted to []byte +func underlyingBytesType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Slice: + if refVal.Type().Elem().Kind() == reflect.Uint8 { + convVal := refVal.Bytes() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + } + + return nil, false +} + +// underlyingStringType gets the underlying type that can be converted to String +func underlyingStringType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.String: + convVal := refVal.String() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingPtrType dereferences a pointer +func underlyingPtrType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + return nil, false +} + +// underlyingTimeType gets the underlying type that can be converted to time.Time +func underlyingTimeType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + timeType := reflect.TypeOf(time.Time{}) + if refVal.Type().ConvertibleTo(timeType) { + return refVal.Convert(timeType).Interface(), true + } + + return nil, false +} + +// underlyingUUIDType gets the underlying type that can be converted to [16]byte +func underlyingUUIDType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return time.Time{}, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + uuidType := reflect.TypeOf([16]byte{}) + if refVal.Type().ConvertibleTo(uuidType) { + return refVal.Convert(uuidType).Interface(), true + } + + return nil, false +} + +// underlyingSliceType gets the underlying slice type +func underlyingSliceType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Slice: + baseSliceType := reflect.SliceOf(refVal.Type().Elem()) + if refVal.Type().ConvertibleTo(baseSliceType) { + convVal := refVal.Convert(baseSliceType) + return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type() + } + } + + return nil, false +} + +func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { + if srcStatus == Present { + switch v := dst.(type) { + case *int: + if srcVal < int64(minInt) { + return fmt.Errorf("%d is less than minimum value for int", srcVal) + } else if srcVal > int64(maxInt) { + return fmt.Errorf("%d is greater than maximum value for int", srcVal) + } + *v = int(srcVal) + case *int8: + if srcVal < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", srcVal) + } else if srcVal > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", srcVal) + } + *v = int8(srcVal) + case *int16: + if srcVal < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for int16", srcVal) + } else if srcVal > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for int16", srcVal) + } + *v = int16(srcVal) + case *int32: + if srcVal < math.MinInt32 { + return fmt.Errorf("%d is less than minimum value for int32", srcVal) + } else if srcVal > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for int32", srcVal) + } + *v = int32(srcVal) + case *int64: + if srcVal < math.MinInt64 { + return fmt.Errorf("%d is less than minimum value for int64", srcVal) + } else if srcVal > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for int64", srcVal) + } + *v = int64(srcVal) + case *uint: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint", srcVal) + } else if uint64(srcVal) > uint64(maxUint) { + return fmt.Errorf("%d is greater than maximum value for uint", srcVal) + } + *v = uint(srcVal) + case *uint8: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint8", srcVal) + } else if srcVal > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", srcVal) + } + *v = uint8(srcVal) + case *uint16: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint32", srcVal) + } else if srcVal > math.MaxUint16 { + return fmt.Errorf("%d is greater than maximum value for uint16", srcVal) + } + *v = uint16(srcVal) + case *uint32: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint32", srcVal) + } else if srcVal > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for uint32", srcVal) + } + *v = uint32(srcVal) + case *uint64: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint64", srcVal) + } + *v = uint64(srcVal) + case sql.Scanner: + return v.Scan(srcVal) + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return int64AssignTo(srcVal, srcStatus, el.Interface()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if el.OverflowInt(int64(srcVal)) { + return fmt.Errorf("cannot put %d into %T", srcVal, dst) + } + el.SetInt(int64(srcVal)) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for %T", srcVal, dst) + } + if el.OverflowUint(uint64(srcVal)) { + return fmt.Errorf("cannot put %d into %T", srcVal, dst) + } + el.SetUint(uint64(srcVal)) + return nil + } + } + return fmt.Errorf("cannot assign %v into %T", srcVal, dst) + } + return nil + } + + // if dst is a pointer to pointer and srcStatus is not Present, nil it out + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + if el.Kind() == reflect.Ptr { + el.Set(reflect.Zero(el.Type())) + return nil + } + } + + return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) +} + +func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { + if srcStatus == Present { + switch v := dst.(type) { + case *float32: + *v = float32(srcVal) + case *float64: + *v = srcVal + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return float64AssignTo(srcVal, srcStatus, el.Interface()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + i64 := int64(srcVal) + if float64(i64) == srcVal { + return int64AssignTo(i64, srcStatus, dst) + } + } + } + return fmt.Errorf("cannot assign %v into %T", srcVal, dst) + } + return nil + } + + // if dst is a pointer to pointer and srcStatus is not Present, nil it out + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + if el.Kind() == reflect.Ptr { + el.Set(reflect.Zero(el.Type())) + return nil + } + } + + return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) +} + +func NullAssignTo(dst interface{}) error { + dstPtr := reflect.ValueOf(dst) + + // AssignTo dst must always be a pointer + if dstPtr.Kind() != reflect.Ptr { + return &nullAssignmentError{dst: dst} + } + + dstVal := dstPtr.Elem() + + switch dstVal.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Map: + dstVal.Set(reflect.Zero(dstVal.Type())) + return nil + } + + return &nullAssignmentError{dst: dst} +} + +var kindTypes map[reflect.Kind]reflect.Type + +func toInterface(dst reflect.Value, t reflect.Type) (interface{}, bool) { + nextDst := dst.Convert(t) + return nextDst.Interface(), dst.Type() != nextDst.Type() +} + +// GetAssignToDstType attempts to convert dst to something AssignTo can assign +// to. If dst is a pointer to pointer it allocates a value and returns the +// dereferences pointer. If dst is a named type such as *Foo where Foo is type +// Foo int16, it converts dst to *int16. +// +// GetAssignToDstType returns the converted dst and a bool representing if any +// change was made. +func GetAssignToDstType(dst interface{}) (interface{}, bool) { + dstPtr := reflect.ValueOf(dst) + + // AssignTo dst must always be a pointer + if dstPtr.Kind() != reflect.Ptr { + return nil, false + } + + dstVal := dstPtr.Elem() + + // if dst is a pointer to pointer, allocate space try again with the dereferenced pointer + if dstVal.Kind() == reflect.Ptr { + dstVal.Set(reflect.New(dstVal.Type().Elem())) + return dstVal.Interface(), true + } + + // if dst is pointer to a base type that has been renamed + if baseValType, ok := kindTypes[dstVal.Kind()]; ok { + return toInterface(dstPtr, reflect.PtrTo(baseValType)) + } + + if dstVal.Kind() == reflect.Slice { + if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { + return toInterface(dstPtr, reflect.PtrTo(reflect.SliceOf(baseElemType))) + } + } + + if dstVal.Kind() == reflect.Array { + if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { + return toInterface(dstPtr, reflect.PtrTo(reflect.ArrayOf(dstVal.Len(), baseElemType))) + } + } + + if dstVal.Kind() == reflect.Struct { + if dstVal.Type().NumField() == 1 && dstVal.Type().Field(0).Anonymous { + dstPtr = dstVal.Field(0).Addr() + nested := dstVal.Type().Field(0).Type + if nested.Kind() == reflect.Array { + if baseElemType, ok := kindTypes[nested.Elem().Kind()]; ok { + return toInterface(dstPtr, reflect.PtrTo(reflect.ArrayOf(nested.Len(), baseElemType))) + } + } + if _, ok := kindTypes[nested.Kind()]; ok && dstPtr.CanInterface() { + return dstPtr.Interface(), true + } + } + } + + return nil, false +} + +func init() { + kindTypes = map[reflect.Kind]reflect.Type{ + reflect.Bool: reflect.TypeOf(false), + reflect.Float32: reflect.TypeOf(float32(0)), + reflect.Float64: reflect.TypeOf(float64(0)), + reflect.Int: reflect.TypeOf(int(0)), + reflect.Int8: reflect.TypeOf(int8(0)), + reflect.Int16: reflect.TypeOf(int16(0)), + reflect.Int32: reflect.TypeOf(int32(0)), + reflect.Int64: reflect.TypeOf(int64(0)), + reflect.Uint: reflect.TypeOf(uint(0)), + reflect.Uint8: reflect.TypeOf(uint8(0)), + reflect.Uint16: reflect.TypeOf(uint16(0)), + reflect.Uint32: reflect.TypeOf(uint32(0)), + reflect.Uint64: reflect.TypeOf(uint64(0)), + reflect.String: reflect.TypeOf(""), + } +} diff --git a/vendor/github.com/jackc/pgtype/database_sql.go b/vendor/github.com/jackc/pgtype/database_sql.go new file mode 100644 index 000000000..9d1cf8226 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/database_sql.go @@ -0,0 +1,41 @@ +package pgtype + +import ( + "database/sql/driver" + "errors" +) + +func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { + if valuer, ok := src.(driver.Valuer); ok { + return valuer.Value() + } + + if textEncoder, ok := src.(TextEncoder); ok { + buf, err := textEncoder.EncodeText(ci, nil) + if err != nil { + return nil, err + } + return string(buf), nil + } + + if binaryEncoder, ok := src.(BinaryEncoder); ok { + buf, err := binaryEncoder.EncodeBinary(ci, nil) + if err != nil { + return nil, err + } + return buf, nil + } + + return nil, errors.New("cannot convert to database/sql compatible value") +} + +func EncodeValueText(src TextEncoder) (interface{}, error) { + buf, err := src.EncodeText(nil, make([]byte, 0, 32)) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + return string(buf), err +} diff --git a/vendor/github.com/jackc/pgtype/date.go b/vendor/github.com/jackc/pgtype/date.go new file mode 100644 index 000000000..e8d21a78c --- /dev/null +++ b/vendor/github.com/jackc/pgtype/date.go @@ -0,0 +1,287 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "encoding/json" + "fmt" + "time" + + "github.com/jackc/pgio" +) + +type Date struct { + Time time.Time + Status Status + InfinityModifier InfinityModifier +} + +const ( + negativeInfinityDayOffset = -2147483648 + infinityDayOffset = 2147483647 +) + +func (dst *Date) Set(src interface{}) error { + if src == nil { + *dst = Date{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case time.Time: + *dst = Date{Time: value, Status: Present} + case string: + return dst.DecodeText(nil, []byte(value)) + case *time.Time: + if value == nil { + *dst = Date{Status: Null} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Date{Status: Null} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Date", value) + } + + return nil +} + +func (dst Date) Get() interface{} { + switch dst.Status { + case Present: + if dst.InfinityModifier != None { + return dst.InfinityModifier + } + return dst.Time + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Date) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *time.Time: + if src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (dst *Date) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Date{Status: Null} + return nil + } + + sbuf := string(src) + switch sbuf { + case "infinity": + *dst = Date{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *dst = Date{Status: Present, InfinityModifier: -Infinity} + default: + t, err := time.ParseInLocation("2006-01-02", sbuf, time.UTC) + if err != nil { + return err + } + + *dst = Date{Time: t, Status: Present} + } + + return nil +} + +func (dst *Date) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Date{Status: Null} + return nil + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for date: %v", len(src)) + } + + dayOffset := int32(binary.BigEndian.Uint32(src)) + + switch dayOffset { + case infinityDayOffset: + *dst = Date{Status: Present, InfinityModifier: Infinity} + case negativeInfinityDayOffset: + *dst = Date{Status: Present, InfinityModifier: -Infinity} + default: + t := time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.UTC) + *dst = Date{Time: t, Status: Present} + } + + return nil +} + +func (src Date) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var s string + + switch src.InfinityModifier { + case None: + s = src.Time.Format("2006-01-02") + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return append(buf, s...), nil +} + +func (src Date) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var daysSinceDateEpoch int32 + switch src.InfinityModifier { + case None: + tUnix := time.Date(src.Time.Year(), src.Time.Month(), src.Time.Day(), 0, 0, 0, 0, time.UTC).Unix() + dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() + + secSinceDateEpoch := tUnix - dateEpoch + daysSinceDateEpoch = int32(secSinceDateEpoch / 86400) + case Infinity: + daysSinceDateEpoch = infinityDayOffset + case NegativeInfinity: + daysSinceDateEpoch = negativeInfinityDayOffset + } + + return pgio.AppendInt32(buf, daysSinceDateEpoch), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Date) Scan(src interface{}) error { + if src == nil { + *dst = Date{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + case time.Time: + *dst = Date{Time: src, Status: Present} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Date) Value() (driver.Value, error) { + switch src.Status { + case Present: + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} + +func (src Date) MarshalJSON() ([]byte, error) { + switch src.Status { + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + + if src.Status != Present { + return nil, errBadStatus + } + + var s string + + switch src.InfinityModifier { + case None: + s = src.Time.Format("2006-01-02") + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return json.Marshal(s) +} + +func (dst *Date) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *dst = Date{Status: Null} + return nil + } + + switch *s { + case "infinity": + *dst = Date{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *dst = Date{Status: Present, InfinityModifier: -Infinity} + default: + t, err := time.ParseInLocation("2006-01-02", *s, time.UTC) + if err != nil { + return err + } + + *dst = Date{Time: t, Status: Present} + } + + return nil +} diff --git a/vendor/github.com/jackc/pgtype/date_array.go b/vendor/github.com/jackc/pgtype/date_array.go new file mode 100644 index 000000000..24152fa0e --- /dev/null +++ b/vendor/github.com/jackc/pgtype/date_array.go @@ -0,0 +1,518 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + "time" + + "github.com/jackc/pgio" +) + +type DateArray struct { + Elements []Date + Dimensions []ArrayDimension + Status Status +} + +func (dst *DateArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = DateArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []time.Time: + if value == nil { + *dst = DateArray{Status: Null} + } else if len(value) == 0 { + *dst = DateArray{Status: Present} + } else { + elements := make([]Date, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = DateArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*time.Time: + if value == nil { + *dst = DateArray{Status: Null} + } else if len(value) == 0 { + *dst = DateArray{Status: Present} + } else { + elements := make([]Date, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = DateArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Date: + if value == nil { + *dst = DateArray{Status: Null} + } else if len(value) == 0 { + *dst = DateArray{Status: Present} + } else { + *dst = DateArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = DateArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for DateArray", src) + } + if elementsLength == 0 { + *dst = DateArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to DateArray", src) + } + + *dst = DateArray{ + Elements: make([]Date, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Date, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to DateArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *DateArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to DateArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in DateArray", err) + } + index++ + + return index, nil +} + +func (dst DateArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *DateArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]time.Time: + *v = make([]time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*time.Time: + *v = make([]*time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *DateArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from DateArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from DateArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = DateArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Date + + if len(uta.Elements) > 0 { + elements = make([]Date, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Date + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = DateArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *DateArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = DateArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = DateArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Date, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = DateArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src DateArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src DateArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("date"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "date") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *DateArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src DateArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/daterange.go b/vendor/github.com/jackc/pgtype/daterange.go new file mode 100644 index 000000000..63164a5a5 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/daterange.go @@ -0,0 +1,267 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + + "github.com/jackc/pgio" +) + +type Daterange struct { + Lower Date + Upper Date + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Daterange) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Daterange{Status: Null} + return nil + } + + switch value := src.(type) { + case Daterange: + *dst = value + case *Daterange: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + default: + return fmt.Errorf("cannot convert %v to Daterange", src) + } + + return nil +} + +func (dst Daterange) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Daterange) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Daterange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Daterange{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Daterange{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Daterange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Daterange{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Daterange{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Daterange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + var err error + + if src.LowerType != Unbounded { + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if src.UpperType != Unbounded { + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return buf, nil +} + +func (src Daterange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + buf = append(buf, rangeType) + + var err error + + if src.LowerType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if src.UpperType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Daterange) Scan(src interface{}) error { + if src == nil { + *dst = Daterange{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Daterange) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/vendor/github.com/jackc/pgtype/enum_array.go b/vendor/github.com/jackc/pgtype/enum_array.go new file mode 100644 index 000000000..59b5a3edc --- /dev/null +++ b/vendor/github.com/jackc/pgtype/enum_array.go @@ -0,0 +1,428 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "fmt" + "reflect" +) + +type EnumArray struct { + Elements []GenericText + Dimensions []ArrayDimension + Status Status +} + +func (dst *EnumArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = EnumArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []string: + if value == nil { + *dst = EnumArray{Status: Null} + } else if len(value) == 0 { + *dst = EnumArray{Status: Present} + } else { + elements := make([]GenericText, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = EnumArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*string: + if value == nil { + *dst = EnumArray{Status: Null} + } else if len(value) == 0 { + *dst = EnumArray{Status: Present} + } else { + elements := make([]GenericText, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = EnumArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []GenericText: + if value == nil { + *dst = EnumArray{Status: Null} + } else if len(value) == 0 { + *dst = EnumArray{Status: Present} + } else { + *dst = EnumArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = EnumArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for EnumArray", src) + } + if elementsLength == 0 { + *dst = EnumArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to EnumArray", src) + } + + *dst = EnumArray{ + Elements: make([]GenericText, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]GenericText, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to EnumArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *EnumArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to EnumArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in EnumArray", err) + } + index++ + + return index, nil +} + +func (dst EnumArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *EnumArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *EnumArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from EnumArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from EnumArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *EnumArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = EnumArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []GenericText + + if len(uta.Elements) > 0 { + elements = make([]GenericText, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem GenericText + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = EnumArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (src EnumArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *EnumArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src EnumArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/enum_type.go b/vendor/github.com/jackc/pgtype/enum_type.go new file mode 100644 index 000000000..d340320fa --- /dev/null +++ b/vendor/github.com/jackc/pgtype/enum_type.go @@ -0,0 +1,168 @@ +package pgtype + +import "fmt" + +// EnumType represents a enum type. While it implements Value, this is only in service of its type conversion duties +// when registered as a data type in a ConnType. It should not be used directly as a Value. +type EnumType struct { + value string + status Status + + typeName string // PostgreSQL type name + members []string // enum members + membersMap map[string]string // map to quickly lookup member and reuse string instead of allocating +} + +// NewEnumType initializes a new EnumType. It retains a read-only reference to members. members must not be changed. +func NewEnumType(typeName string, members []string) *EnumType { + et := &EnumType{typeName: typeName, members: members} + et.membersMap = make(map[string]string, len(members)) + for _, m := range members { + et.membersMap[m] = m + } + return et +} + +func (et *EnumType) NewTypeValue() Value { + return &EnumType{ + value: et.value, + status: et.status, + + typeName: et.typeName, + members: et.members, + membersMap: et.membersMap, + } +} + +func (et *EnumType) TypeName() string { + return et.typeName +} + +func (et *EnumType) Members() []string { + return et.members +} + +// Set assigns src to dst. Set purposely does not check that src is a member. This allows continued error free +// operation in the event the PostgreSQL enum type is modified during a connection. +func (dst *EnumType) Set(src interface{}) error { + if src == nil { + dst.status = Null + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case string: + dst.value = value + dst.status = Present + case *string: + if value == nil { + dst.status = Null + } else { + dst.value = *value + dst.status = Present + } + case []byte: + if value == nil { + dst.status = Null + } else { + dst.value = string(value) + dst.status = Present + } + default: + if originalSrc, ok := underlyingStringType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to enum %s", value, dst.typeName) + } + + return nil +} + +func (dst EnumType) Get() interface{} { + switch dst.status { + case Present: + return dst.value + case Null: + return nil + default: + return dst.status + } +} + +func (src *EnumType) AssignTo(dst interface{}) error { + switch src.status { + case Present: + switch v := dst.(type) { + case *string: + *v = src.value + return nil + case *[]byte: + *v = make([]byte, len(src.value)) + copy(*v, src.value) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (EnumType) PreferredResultFormat() int16 { + return TextFormatCode +} + +func (dst *EnumType) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + dst.status = Null + return nil + } + + // Lookup the string in membersMap to avoid an allocation. + if s, found := dst.membersMap[string(src)]; found { + dst.value = s + } else { + // If an enum type is modified after the initial connection it is possible to receive an unexpected value. + // Gracefully handle this situation. Purposely NOT modifying members and membersMap to allow for sharing members + // and membersMap between connections. + dst.value = string(src) + } + dst.status = Present + + return nil +} + +func (dst *EnumType) DecodeBinary(ci *ConnInfo, src []byte) error { + return dst.DecodeText(ci, src) +} + +func (EnumType) PreferredParamFormat() int16 { + return TextFormatCode +} + +func (src EnumType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, src.value...), nil +} + +func (src EnumType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return src.EncodeText(ci, buf) +} diff --git a/vendor/github.com/jackc/pgtype/float4.go b/vendor/github.com/jackc/pgtype/float4.go new file mode 100644 index 000000000..89b9e8fae --- /dev/null +++ b/vendor/github.com/jackc/pgtype/float4.go @@ -0,0 +1,282 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgio" +) + +type Float4 struct { + Float float32 + Status Status +} + +func (dst *Float4) Set(src interface{}) error { + if src == nil { + *dst = Float4{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case float32: + *dst = Float4{Float: value, Status: Present} + case float64: + *dst = Float4{Float: float32(value), Status: Present} + case int8: + *dst = Float4{Float: float32(value), Status: Present} + case uint8: + *dst = Float4{Float: float32(value), Status: Present} + case int16: + *dst = Float4{Float: float32(value), Status: Present} + case uint16: + *dst = Float4{Float: float32(value), Status: Present} + case int32: + f32 := float32(value) + if int32(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case uint32: + f32 := float32(value) + if uint32(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case int64: + f32 := float32(value) + if int64(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case uint64: + f32 := float32(value) + if uint64(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case int: + f32 := float32(value) + if int(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case uint: + f32 := float32(value) + if uint(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case string: + num, err := strconv.ParseFloat(value, 32) + if err != nil { + return err + } + *dst = Float4{Float: float32(num), Status: Present} + case *float64: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *float32: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *int8: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *uint8: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *int16: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *uint16: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *int32: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *uint32: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *int64: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *uint64: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *int: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *uint: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Float4{Status: Null} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Float8", value) + } + + return nil +} + +func (dst Float4) Get() interface{} { + switch dst.Status { + case Present: + return dst.Float + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Float4) AssignTo(dst interface{}) error { + return float64AssignTo(float64(src.Float), src.Status, dst) +} + +func (dst *Float4) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float4{Status: Null} + return nil + } + + n, err := strconv.ParseFloat(string(src), 32) + if err != nil { + return err + } + + *dst = Float4{Float: float32(n), Status: Present} + return nil +} + +func (dst *Float4) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float4{Status: Null} + return nil + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for float4: %v", len(src)) + } + + n := int32(binary.BigEndian.Uint32(src)) + + *dst = Float4{Float: math.Float32frombits(uint32(n)), Status: Present} + return nil +} + +func (src Float4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, strconv.FormatFloat(float64(src.Float), 'f', -1, 32)...) + return buf, nil +} + +func (src Float4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendUint32(buf, math.Float32bits(src.Float)) + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Float4) Scan(src interface{}) error { + if src == nil { + *dst = Float4{Status: Null} + return nil + } + + switch src := src.(type) { + case float64: + *dst = Float4{Float: float32(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Float4) Value() (driver.Value, error) { + switch src.Status { + case Present: + return float64(src.Float), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/vendor/github.com/jackc/pgtype/float4_array.go b/vendor/github.com/jackc/pgtype/float4_array.go new file mode 100644 index 000000000..41f2ec8f4 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/float4_array.go @@ -0,0 +1,517 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type Float4Array struct { + Elements []Float4 + Dimensions []ArrayDimension + Status Status +} + +func (dst *Float4Array) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Float4Array{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []float32: + if value == nil { + *dst = Float4Array{Status: Null} + } else if len(value) == 0 { + *dst = Float4Array{Status: Present} + } else { + elements := make([]Float4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Float4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*float32: + if value == nil { + *dst = Float4Array{Status: Null} + } else if len(value) == 0 { + *dst = Float4Array{Status: Present} + } else { + elements := make([]Float4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Float4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Float4: + if value == nil { + *dst = Float4Array{Status: Null} + } else if len(value) == 0 { + *dst = Float4Array{Status: Present} + } else { + *dst = Float4Array{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = Float4Array{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for Float4Array", src) + } + if elementsLength == 0 { + *dst = Float4Array{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Float4Array", src) + } + + *dst = Float4Array{ + Elements: make([]Float4, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Float4, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to Float4Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *Float4Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to Float4Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in Float4Array", err) + } + index++ + + return index, nil +} + +func (dst Float4Array) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Float4Array) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]float32: + *v = make([]float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*float32: + *v = make([]*float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *Float4Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from Float4Array") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from Float4Array") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float4Array{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Float4 + + if len(uta.Elements) > 0 { + elements = make([]Float4, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Float4 + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Float4Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *Float4Array) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float4Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Float4Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Float4, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = Float4Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src Float4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src Float4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("float4"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "float4") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Float4Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Float4Array) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/float8.go b/vendor/github.com/jackc/pgtype/float8.go new file mode 100644 index 000000000..4d9e7116a --- /dev/null +++ b/vendor/github.com/jackc/pgtype/float8.go @@ -0,0 +1,272 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgio" +) + +type Float8 struct { + Float float64 + Status Status +} + +func (dst *Float8) Set(src interface{}) error { + if src == nil { + *dst = Float8{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case float32: + *dst = Float8{Float: float64(value), Status: Present} + case float64: + *dst = Float8{Float: value, Status: Present} + case int8: + *dst = Float8{Float: float64(value), Status: Present} + case uint8: + *dst = Float8{Float: float64(value), Status: Present} + case int16: + *dst = Float8{Float: float64(value), Status: Present} + case uint16: + *dst = Float8{Float: float64(value), Status: Present} + case int32: + *dst = Float8{Float: float64(value), Status: Present} + case uint32: + *dst = Float8{Float: float64(value), Status: Present} + case int64: + f64 := float64(value) + if int64(f64) == value { + *dst = Float8{Float: f64, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float64", value) + } + case uint64: + f64 := float64(value) + if uint64(f64) == value { + *dst = Float8{Float: f64, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float64", value) + } + case int: + f64 := float64(value) + if int(f64) == value { + *dst = Float8{Float: f64, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float64", value) + } + case uint: + f64 := float64(value) + if uint(f64) == value { + *dst = Float8{Float: f64, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float64", value) + } + case string: + num, err := strconv.ParseFloat(value, 64) + if err != nil { + return err + } + *dst = Float8{Float: float64(num), Status: Present} + case *float64: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *float32: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *int8: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *uint8: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *int16: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *uint16: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *int32: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *uint32: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *int64: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *uint64: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *int: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *uint: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Float8{Status: Null} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Float8", value) + } + + return nil +} + +func (dst Float8) Get() interface{} { + switch dst.Status { + case Present: + return dst.Float + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Float8) AssignTo(dst interface{}) error { + return float64AssignTo(src.Float, src.Status, dst) +} + +func (dst *Float8) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float8{Status: Null} + return nil + } + + n, err := strconv.ParseFloat(string(src), 64) + if err != nil { + return err + } + + *dst = Float8{Float: n, Status: Present} + return nil +} + +func (dst *Float8) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float8{Status: Null} + return nil + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for float4: %v", len(src)) + } + + n := int64(binary.BigEndian.Uint64(src)) + + *dst = Float8{Float: math.Float64frombits(uint64(n)), Status: Present} + return nil +} + +func (src Float8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, strconv.FormatFloat(float64(src.Float), 'f', -1, 64)...) + return buf, nil +} + +func (src Float8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendUint64(buf, math.Float64bits(src.Float)) + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Float8) Scan(src interface{}) error { + if src == nil { + *dst = Float8{Status: Null} + return nil + } + + switch src := src.(type) { + case float64: + *dst = Float8{Float: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Float8) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Float, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/vendor/github.com/jackc/pgtype/float8_array.go b/vendor/github.com/jackc/pgtype/float8_array.go new file mode 100644 index 000000000..836ee19dc --- /dev/null +++ b/vendor/github.com/jackc/pgtype/float8_array.go @@ -0,0 +1,517 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type Float8Array struct { + Elements []Float8 + Dimensions []ArrayDimension + Status Status +} + +func (dst *Float8Array) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Float8Array{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []float64: + if value == nil { + *dst = Float8Array{Status: Null} + } else if len(value) == 0 { + *dst = Float8Array{Status: Present} + } else { + elements := make([]Float8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Float8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*float64: + if value == nil { + *dst = Float8Array{Status: Null} + } else if len(value) == 0 { + *dst = Float8Array{Status: Present} + } else { + elements := make([]Float8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Float8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Float8: + if value == nil { + *dst = Float8Array{Status: Null} + } else if len(value) == 0 { + *dst = Float8Array{Status: Present} + } else { + *dst = Float8Array{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = Float8Array{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for Float8Array", src) + } + if elementsLength == 0 { + *dst = Float8Array{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Float8Array", src) + } + + *dst = Float8Array{ + Elements: make([]Float8, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Float8, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to Float8Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *Float8Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to Float8Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in Float8Array", err) + } + index++ + + return index, nil +} + +func (dst Float8Array) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Float8Array) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]float64: + *v = make([]float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*float64: + *v = make([]*float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *Float8Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from Float8Array") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from Float8Array") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float8Array{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Float8 + + if len(uta.Elements) > 0 { + elements = make([]Float8, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Float8 + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Float8Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *Float8Array) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Float8Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Float8Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Float8, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = Float8Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src Float8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src Float8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("float8"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "float8") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Float8Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Float8Array) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/generic_binary.go b/vendor/github.com/jackc/pgtype/generic_binary.go new file mode 100644 index 000000000..76a1d3511 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/generic_binary.go @@ -0,0 +1,39 @@ +package pgtype + +import ( + "database/sql/driver" +) + +// GenericBinary is a placeholder for binary format values that no other type exists +// to handle. +type GenericBinary Bytea + +func (dst *GenericBinary) Set(src interface{}) error { + return (*Bytea)(dst).Set(src) +} + +func (dst GenericBinary) Get() interface{} { + return (Bytea)(dst).Get() +} + +func (src *GenericBinary) AssignTo(dst interface{}) error { + return (*Bytea)(src).AssignTo(dst) +} + +func (dst *GenericBinary) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Bytea)(dst).DecodeBinary(ci, src) +} + +func (src GenericBinary) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Bytea)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *GenericBinary) Scan(src interface{}) error { + return (*Bytea)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src GenericBinary) Value() (driver.Value, error) { + return (Bytea)(src).Value() +} diff --git a/vendor/github.com/jackc/pgtype/generic_text.go b/vendor/github.com/jackc/pgtype/generic_text.go new file mode 100644 index 000000000..dbf5b47e8 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/generic_text.go @@ -0,0 +1,39 @@ +package pgtype + +import ( + "database/sql/driver" +) + +// GenericText is a placeholder for text format values that no other type exists +// to handle. +type GenericText Text + +func (dst *GenericText) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst GenericText) Get() interface{} { + return (Text)(dst).Get() +} + +func (src *GenericText) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *GenericText) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (src GenericText) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Text)(src).EncodeText(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *GenericText) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src GenericText) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/vendor/github.com/jackc/pgtype/go.mod b/vendor/github.com/jackc/pgtype/go.mod new file mode 100644 index 000000000..63bae8798 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/go.mod @@ -0,0 +1,13 @@ +module github.com/jackc/pgtype + +go 1.13 + +require ( + github.com/gofrs/uuid v4.0.0+incompatible + github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530 + github.com/jackc/pgio v1.0.0 + github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c + github.com/lib/pq v1.10.2 + github.com/shopspring/decimal v1.2.0 + github.com/stretchr/testify v1.7.0 +) diff --git a/vendor/github.com/jackc/pgtype/go.sum b/vendor/github.com/jackc/pgtype/go.sum new file mode 100644 index 000000000..8f2d760e4 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/go.sum @@ -0,0 +1,175 @@ +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= +github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= +github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= +github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= +github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= +github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= +github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= +github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= +github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= +github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= +github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530 h1:dUJ578zuPEsXjtzOfEF0q9zDAfljJ9oFnTHcQaNkccw= +github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= +github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= +github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI= +github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= +github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= +github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= +github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= +github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= +github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= +github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= +github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c h1:Dznn52SgVIVst9UyOT9brctYUgxs+CvVfPaC3jKrA50= +github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= +github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= +github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= +github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= +github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= +github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= +go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= +go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/vendor/github.com/jackc/pgtype/hstore.go b/vendor/github.com/jackc/pgtype/hstore.go new file mode 100644 index 000000000..18b413c6b --- /dev/null +++ b/vendor/github.com/jackc/pgtype/hstore.go @@ -0,0 +1,439 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "errors" + "fmt" + "strings" + "unicode" + "unicode/utf8" + + "github.com/jackc/pgio" +) + +// Hstore represents an hstore column that can be null or have null values +// associated with its keys. +type Hstore struct { + Map map[string]Text + Status Status +} + +func (dst *Hstore) Set(src interface{}) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case map[string]string: + m := make(map[string]Text, len(value)) + for k, v := range value { + m[k] = Text{String: v, Status: Present} + } + *dst = Hstore{Map: m, Status: Present} + default: + return fmt.Errorf("cannot convert %v to Hstore", src) + } + + return nil +} + +func (dst Hstore) Get() interface{} { + switch dst.Status { + case Present: + return dst.Map + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Hstore) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *map[string]string: + *v = make(map[string]string, len(src.Map)) + for k, val := range src.Map { + if val.Status != Present { + return fmt.Errorf("cannot decode %#v into %T", src, dst) + } + (*v)[k] = val.String + } + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + + keys, values, err := parseHstore(string(src)) + if err != nil { + return err + } + + m := make(map[string]Text, len(keys)) + for i := range keys { + m[keys[i]] = values[i] + } + + *dst = Hstore{Map: m, Status: Present} + return nil +} + +func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + + rp := 0 + + if len(src[rp:]) < 4 { + return fmt.Errorf("hstore incomplete %v", src) + } + pairCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + m := make(map[string]Text, pairCount) + + for i := 0; i < pairCount; i++ { + if len(src[rp:]) < 4 { + return fmt.Errorf("hstore incomplete %v", src) + } + keyLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + if len(src[rp:]) < keyLen { + return fmt.Errorf("hstore incomplete %v", src) + } + key := string(src[rp : rp+keyLen]) + rp += keyLen + + if len(src[rp:]) < 4 { + return fmt.Errorf("hstore incomplete %v", src) + } + valueLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + var valueBuf []byte + if valueLen >= 0 { + valueBuf = src[rp : rp+valueLen] + } + rp += valueLen + + var value Text + err := value.DecodeBinary(ci, valueBuf) + if err != nil { + return err + } + m[key] = value + } + + *dst = Hstore{Map: m, Status: Present} + + return nil +} + +func (src Hstore) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + firstPair := true + + inElemBuf := make([]byte, 0, 32) + for k, v := range src.Map { + if firstPair { + firstPair = false + } else { + buf = append(buf, ',') + } + + buf = append(buf, quoteHstoreElementIfNeeded(k)...) + buf = append(buf, "=>"...) + + elemBuf, err := v.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + + if elemBuf == nil { + buf = append(buf, "NULL"...) + } else { + buf = append(buf, quoteHstoreElementIfNeeded(string(elemBuf))...) + } + } + + return buf, nil +} + +func (src Hstore) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendInt32(buf, int32(len(src.Map))) + + var err error + for k, v := range src.Map { + buf = pgio.AppendInt32(buf, int32(len(k))) + buf = append(buf, k...) + + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := v.EncodeText(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, err +} + +var quoteHstoreReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func quoteHstoreElement(src string) string { + return `"` + quoteArrayReplacer.Replace(src) + `"` +} + +func quoteHstoreElementIfNeeded(src string) string { + if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || strings.ContainsAny(src, ` {},"\=>`) { + return quoteArrayElement(src) + } + return src +} + +const ( + hsPre = iota + hsKey + hsSep + hsVal + hsNul + hsNext +) + +type hstoreParser struct { + str string + pos int +} + +func newHSP(in string) *hstoreParser { + return &hstoreParser{ + pos: 0, + str: in, + } +} + +func (p *hstoreParser) Consume() (r rune, end bool) { + if p.pos >= len(p.str) { + end = true + return + } + r, w := utf8.DecodeRuneInString(p.str[p.pos:]) + p.pos += w + return +} + +func (p *hstoreParser) Peek() (r rune, end bool) { + if p.pos >= len(p.str) { + end = true + return + } + r, _ = utf8.DecodeRuneInString(p.str[p.pos:]) + return +} + +// parseHstore parses the string representation of an hstore column (the same +// you would get from an ordinary SELECT) into two slices of keys and values. it +// is used internally in the default parsing of hstores. +func parseHstore(s string) (k []string, v []Text, err error) { + if s == "" { + return + } + + buf := bytes.Buffer{} + keys := []string{} + values := []Text{} + p := newHSP(s) + + r, end := p.Consume() + state := hsPre + + for !end { + switch state { + case hsPre: + if r == '"' { + state = hsKey + } else { + err = errors.New("String does not begin with \"") + } + case hsKey: + switch r { + case '"': //End of the key + keys = append(keys, buf.String()) + buf = bytes.Buffer{} + state = hsSep + case '\\': //Potential escaped character + n, end := p.Consume() + switch { + case end: + err = errors.New("Found EOS in key, expecting character or \"") + case n == '"', n == '\\': + buf.WriteRune(n) + default: + buf.WriteRune(r) + buf.WriteRune(n) + } + default: //Any other character + buf.WriteRune(r) + } + case hsSep: + if r == '=' { + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after '=', expecting '>'") + case r == '>': + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'") + case r == '"': + state = hsVal + case r == 'N': + state = hsNul + default: + err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) + } + default: + err = fmt.Errorf("Invalid character after '=', expecting '>'") + } + } else { + err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r) + } + case hsVal: + switch r { + case '"': //End of the value + values = append(values, Text{String: buf.String(), Status: Present}) + buf = bytes.Buffer{} + state = hsNext + case '\\': //Potential escaped character + n, end := p.Consume() + switch { + case end: + err = errors.New("Found EOS in key, expecting character or \"") + case n == '"', n == '\\': + buf.WriteRune(n) + default: + buf.WriteRune(r) + buf.WriteRune(n) + } + default: //Any other character + buf.WriteRune(r) + } + case hsNul: + nulBuf := make([]rune, 3) + nulBuf[0] = r + for i := 1; i < 3; i++ { + r, end = p.Consume() + if end { + err = errors.New("Found EOS in NULL value") + return + } + nulBuf[i] = r + } + if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' { + values = append(values, Text{Status: Null}) + state = hsNext + } else { + err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) + } + case hsNext: + if r == ',' { + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after ',', expcting space") + case (unicode.IsSpace(r)): + r, end = p.Consume() + state = hsKey + default: + err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r) + } + } else { + err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r) + } + } + + if err != nil { + return + } + r, end = p.Consume() + } + if state != hsNext { + err = errors.New("Improperly formatted hstore") + return + } + k = keys + v = values + return +} + +// Scan implements the database/sql Scanner interface. +func (dst *Hstore) Scan(src interface{}) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Hstore) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/vendor/github.com/jackc/pgtype/hstore_array.go b/vendor/github.com/jackc/pgtype/hstore_array.go new file mode 100644 index 000000000..47b4b3fff --- /dev/null +++ b/vendor/github.com/jackc/pgtype/hstore_array.go @@ -0,0 +1,489 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type HstoreArray struct { + Elements []Hstore + Dimensions []ArrayDimension + Status Status +} + +func (dst *HstoreArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = HstoreArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []map[string]string: + if value == nil { + *dst = HstoreArray{Status: Null} + } else if len(value) == 0 { + *dst = HstoreArray{Status: Present} + } else { + elements := make([]Hstore, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = HstoreArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Hstore: + if value == nil { + *dst = HstoreArray{Status: Null} + } else if len(value) == 0 { + *dst = HstoreArray{Status: Present} + } else { + *dst = HstoreArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = HstoreArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for HstoreArray", src) + } + if elementsLength == 0 { + *dst = HstoreArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to HstoreArray", src) + } + + *dst = HstoreArray{ + Elements: make([]Hstore, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Hstore, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to HstoreArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *HstoreArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to HstoreArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in HstoreArray", err) + } + index++ + + return index, nil +} + +func (dst HstoreArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *HstoreArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]map[string]string: + *v = make([]map[string]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *HstoreArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from HstoreArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from HstoreArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *HstoreArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = HstoreArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Hstore + + if len(uta.Elements) > 0 { + elements = make([]Hstore, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Hstore + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = HstoreArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *HstoreArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = HstoreArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = HstoreArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Hstore, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = HstoreArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src HstoreArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src HstoreArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("hstore"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "hstore") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *HstoreArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src HstoreArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/inet.go b/vendor/github.com/jackc/pgtype/inet.go new file mode 100644 index 000000000..1645334e3 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/inet.go @@ -0,0 +1,250 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + "net" +) + +// Network address family is dependent on server socket.h value for AF_INET. +// In practice, all platforms appear to have the same value. See +// src/include/utils/inet.h for more information. +const ( + defaultAFInet = 2 + defaultAFInet6 = 3 +) + +// Inet represents both inet and cidr PostgreSQL types. +type Inet struct { + IPNet *net.IPNet + Status Status +} + +func (dst *Inet) Set(src interface{}) error { + if src == nil { + *dst = Inet{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case net.IPNet: + *dst = Inet{IPNet: &value, Status: Present} + case net.IP: + if len(value) == 0 { + *dst = Inet{Status: Null} + } else { + bitCount := len(value) * 8 + mask := net.CIDRMask(bitCount, bitCount) + *dst = Inet{IPNet: &net.IPNet{Mask: mask, IP: value}, Status: Present} + } + case string: + ip, ipnet, err := net.ParseCIDR(value) + if err != nil { + return err + } + ipnet.IP = ip + *dst = Inet{IPNet: ipnet, Status: Present} + case *net.IPNet: + if value == nil { + *dst = Inet{Status: Null} + } else { + return dst.Set(*value) + } + case *net.IP: + if value == nil { + *dst = Inet{Status: Null} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Inet{Status: Null} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingPtrType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Inet", value) + } + + return nil +} + +func (dst Inet) Get() interface{} { + switch dst.Status { + case Present: + return dst.IPNet + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Inet) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *net.IPNet: + *v = net.IPNet{ + IP: make(net.IP, len(src.IPNet.IP)), + Mask: make(net.IPMask, len(src.IPNet.Mask)), + } + copy(v.IP, src.IPNet.IP) + copy(v.Mask, src.IPNet.Mask) + return nil + case *net.IP: + if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = make(net.IP, len(src.IPNet.IP)) + copy(*v, src.IPNet.IP) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Inet{Status: Null} + return nil + } + + var ipnet *net.IPNet + var err error + + if ip := net.ParseIP(string(src)); ip != nil { + if ipv4 := ip.To4(); ipv4 != nil { + ip = ipv4 + } + bitCount := len(ip) * 8 + mask := net.CIDRMask(bitCount, bitCount) + ipnet = &net.IPNet{Mask: mask, IP: ip} + } else { + ip, ipnet, err = net.ParseCIDR(string(src)) + if err != nil { + return err + } + if ipv4 := ip.To4(); ipv4 != nil { + ip = ipv4 + } + ones, _ := ipnet.Mask.Size() + *ipnet = net.IPNet{IP: ip, Mask: net.CIDRMask(ones, len(ip)*8)} + } + + *dst = Inet{IPNet: ipnet, Status: Present} + return nil +} + +func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Inet{Status: Null} + return nil + } + + if len(src) != 8 && len(src) != 20 { + return fmt.Errorf("Received an invalid size for a inet: %d", len(src)) + } + + // ignore family + bits := src[1] + // ignore is_cidr + addressLength := src[3] + + var ipnet net.IPNet + ipnet.IP = make(net.IP, int(addressLength)) + copy(ipnet.IP, src[4:]) + if ipv4 := ipnet.IP.To4(); ipv4 != nil { + ipnet.IP = ipv4 + } + ipnet.Mask = net.CIDRMask(int(bits), len(ipnet.IP)*8) + + *dst = Inet{IPNet: &ipnet, Status: Present} + + return nil +} + +func (src Inet) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, src.IPNet.String()...), nil +} + +// EncodeBinary encodes src into w. +func (src Inet) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var family byte + switch len(src.IPNet.IP) { + case net.IPv4len: + family = defaultAFInet + case net.IPv6len: + family = defaultAFInet6 + default: + return nil, fmt.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) + } + + buf = append(buf, family) + + ones, _ := src.IPNet.Mask.Size() + buf = append(buf, byte(ones)) + + // is_cidr is ignored on server + buf = append(buf, 0) + + buf = append(buf, byte(len(src.IPNet.IP))) + + return append(buf, src.IPNet.IP...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Inet) Scan(src interface{}) error { + if src == nil { + *dst = Inet{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Inet) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/vendor/github.com/jackc/pgtype/inet_array.go b/vendor/github.com/jackc/pgtype/inet_array.go new file mode 100644 index 000000000..2460a1c4d --- /dev/null +++ b/vendor/github.com/jackc/pgtype/inet_array.go @@ -0,0 +1,546 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "net" + "reflect" + + "github.com/jackc/pgio" +) + +type InetArray struct { + Elements []Inet + Dimensions []ArrayDimension + Status Status +} + +func (dst *InetArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = InetArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []*net.IPNet: + if value == nil { + *dst = InetArray{Status: Null} + } else if len(value) == 0 { + *dst = InetArray{Status: Present} + } else { + elements := make([]Inet, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = InetArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []net.IP: + if value == nil { + *dst = InetArray{Status: Null} + } else if len(value) == 0 { + *dst = InetArray{Status: Present} + } else { + elements := make([]Inet, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = InetArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*net.IP: + if value == nil { + *dst = InetArray{Status: Null} + } else if len(value) == 0 { + *dst = InetArray{Status: Present} + } else { + elements := make([]Inet, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = InetArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Inet: + if value == nil { + *dst = InetArray{Status: Null} + } else if len(value) == 0 { + *dst = InetArray{Status: Present} + } else { + *dst = InetArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = InetArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for InetArray", src) + } + if elementsLength == 0 { + *dst = InetArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to InetArray", src) + } + + *dst = InetArray{ + Elements: make([]Inet, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Inet, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to InetArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *InetArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to InetArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in InetArray", err) + } + index++ + + return index, nil +} + +func (dst InetArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *InetArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]*net.IPNet: + *v = make([]*net.IPNet, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]net.IP: + *v = make([]net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*net.IP: + *v = make([]*net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *InetArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from InetArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from InetArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = InetArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Inet + + if len(uta.Elements) > 0 { + elements = make([]Inet, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Inet + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = InetArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *InetArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = InetArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = InetArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Inet, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = InetArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src InetArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src InetArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("inet"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "inet") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *InetArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src InetArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/int2.go b/vendor/github.com/jackc/pgtype/int2.go new file mode 100644 index 000000000..3eb5aeb55 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/int2.go @@ -0,0 +1,304 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgio" +) + +type Int2 struct { + Int int16 + Status Status +} + +func (dst *Int2) Set(src interface{}) error { + if src == nil { + *dst = Int2{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case int8: + *dst = Int2{Int: int16(value), Status: Present} + case uint8: + *dst = Int2{Int: int16(value), Status: Present} + case int16: + *dst = Int2{Int: int16(value), Status: Present} + case uint16: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Status: Present} + case int32: + if value < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Status: Present} + case uint32: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Status: Present} + case int64: + if value < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Status: Present} + case uint64: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Status: Present} + case int: + if value < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Status: Present} + case uint: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 16) + if err != nil { + return err + } + *dst = Int2{Int: int16(num), Status: Present} + case float32: + if value > math.MaxInt16 { + return fmt.Errorf("%f is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Status: Present} + case float64: + if value > math.MaxInt16 { + return fmt.Errorf("%f is greater than maximum value for Int2", value) + } + *dst = Int2{Int: int16(value), Status: Present} + case *int8: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *uint8: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *int16: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *uint16: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *int32: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *uint32: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *int64: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *uint64: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *int: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *uint: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *float32: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + case *float64: + if value == nil { + *dst = Int2{Status: Null} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int2", value) + } + + return nil +} + +func (dst Int2) Get() interface{} { + switch dst.Status { + case Present: + return dst.Int + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int2) AssignTo(dst interface{}) error { + return int64AssignTo(int64(src.Int), src.Status, dst) +} + +func (dst *Int2) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int2{Status: Null} + return nil + } + + n, err := strconv.ParseInt(string(src), 10, 16) + if err != nil { + return err + } + + *dst = Int2{Int: int16(n), Status: Present} + return nil +} + +func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int2{Status: Null} + return nil + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + n := int16(binary.BigEndian.Uint16(src)) + *dst = Int2{Int: n, Status: Present} + return nil +} + +func (src Int2) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, strconv.FormatInt(int64(src.Int), 10)...), nil +} + +func (src Int2) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return pgio.AppendInt16(buf, src.Int), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int2) Scan(src interface{}) error { + if src == nil { + *dst = Int2{Status: Null} + return nil + } + + switch src := src.(type) { + case int64: + if src < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", src) + } + if src > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", src) + } + *dst = Int2{Int: int16(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int2) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Int), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} + +func (src Int2) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return []byte(strconv.FormatInt(int64(src.Int), 10)), nil + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + + return nil, errBadStatus +} diff --git a/vendor/github.com/jackc/pgtype/int2_array.go b/vendor/github.com/jackc/pgtype/int2_array.go new file mode 100644 index 000000000..a51338450 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/int2_array.go @@ -0,0 +1,909 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type Int2Array struct { + Elements []Int2 + Dimensions []ArrayDimension + Status Status +} + +func (dst *Int2Array) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Int2Array{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []int16: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int16: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint16: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint16: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int32: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int32: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint32: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint32: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int64: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int64: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint64: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint64: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Int2: + if value == nil { + *dst = Int2Array{Status: Null} + } else if len(value) == 0 { + *dst = Int2Array{Status: Present} + } else { + *dst = Int2Array{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = Int2Array{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for Int2Array", src) + } + if elementsLength == 0 { + *dst = Int2Array{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int2Array", src) + } + + *dst = Int2Array{ + Elements: make([]Int2, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Int2, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to Int2Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *Int2Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to Int2Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in Int2Array", err) + } + index++ + + return index, nil +} + +func (dst Int2Array) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int2Array) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]int16: + *v = make([]int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int16: + *v = make([]*int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint16: + *v = make([]uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint16: + *v = make([]*uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int32: + *v = make([]int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int32: + *v = make([]*int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint32: + *v = make([]uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint32: + *v = make([]*uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int64: + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int64: + *v = make([]*int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint64: + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint64: + *v = make([]*uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int: + *v = make([]int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int: + *v = make([]*int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint: + *v = make([]uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint: + *v = make([]*uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *Int2Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from Int2Array") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from Int2Array") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int2Array{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Int2 + + if len(uta.Elements) > 0 { + elements = make([]Int2, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Int2 + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Int2Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *Int2Array) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int2Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Int2Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Int2, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = Int2Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src Int2Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("int2"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "int2") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int2Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int2Array) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/int4.go b/vendor/github.com/jackc/pgtype/int4.go new file mode 100644 index 000000000..22b48e5e5 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/int4.go @@ -0,0 +1,312 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "encoding/json" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgio" +) + +type Int4 struct { + Int int32 + Status Status +} + +func (dst *Int4) Set(src interface{}) error { + if src == nil { + *dst = Int4{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case int8: + *dst = Int4{Int: int32(value), Status: Present} + case uint8: + *dst = Int4{Int: int32(value), Status: Present} + case int16: + *dst = Int4{Int: int32(value), Status: Present} + case uint16: + *dst = Int4{Int: int32(value), Status: Present} + case int32: + *dst = Int4{Int: int32(value), Status: Present} + case uint32: + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *dst = Int4{Int: int32(value), Status: Present} + case int64: + if value < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *dst = Int4{Int: int32(value), Status: Present} + case uint64: + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *dst = Int4{Int: int32(value), Status: Present} + case int: + if value < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *dst = Int4{Int: int32(value), Status: Present} + case uint: + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *dst = Int4{Int: int32(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 32) + if err != nil { + return err + } + *dst = Int4{Int: int32(num), Status: Present} + case float32: + if value > math.MaxInt32 { + return fmt.Errorf("%f is greater than maximum value for Int4", value) + } + *dst = Int4{Int: int32(value), Status: Present} + case float64: + if value > math.MaxInt32 { + return fmt.Errorf("%f is greater than maximum value for Int4", value) + } + *dst = Int4{Int: int32(value), Status: Present} + case *int8: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *uint8: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *int16: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *uint16: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *int32: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *uint32: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *int64: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *uint64: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *int: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *uint: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *float32: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + case *float64: + if value == nil { + *dst = Int4{Status: Null} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int4", value) + } + + return nil +} + +func (dst Int4) Get() interface{} { + switch dst.Status { + case Present: + return dst.Int + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int4) AssignTo(dst interface{}) error { + return int64AssignTo(int64(src.Int), src.Status, dst) +} + +func (dst *Int4) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4{Status: Null} + return nil + } + + n, err := strconv.ParseInt(string(src), 10, 32) + if err != nil { + return err + } + + *dst = Int4{Int: int32(n), Status: Present} + return nil +} + +func (dst *Int4) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4{Status: Null} + return nil + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + n := int32(binary.BigEndian.Uint32(src)) + *dst = Int4{Int: n, Status: Present} + return nil +} + +func (src Int4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, strconv.FormatInt(int64(src.Int), 10)...), nil +} + +func (src Int4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return pgio.AppendInt32(buf, src.Int), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int4) Scan(src interface{}) error { + if src == nil { + *dst = Int4{Status: Null} + return nil + } + + switch src := src.(type) { + case int64: + if src < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", src) + } + if src > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", src) + } + *dst = Int4{Int: int32(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int4) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Int), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} + +func (src Int4) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return []byte(strconv.FormatInt(int64(src.Int), 10)), nil + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + + return nil, errBadStatus +} + +func (dst *Int4) UnmarshalJSON(b []byte) error { + var n *int32 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *dst = Int4{Status: Null} + } else { + *dst = Int4{Int: *n, Status: Present} + } + + return nil +} diff --git a/vendor/github.com/jackc/pgtype/int4_array.go b/vendor/github.com/jackc/pgtype/int4_array.go new file mode 100644 index 000000000..de26236fd --- /dev/null +++ b/vendor/github.com/jackc/pgtype/int4_array.go @@ -0,0 +1,909 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type Int4Array struct { + Elements []Int4 + Dimensions []ArrayDimension + Status Status +} + +func (dst *Int4Array) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Int4Array{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []int16: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int16: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint16: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint16: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int32: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int32: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint32: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint32: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int64: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int64: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint64: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint64: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Int4: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + *dst = Int4Array{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = Int4Array{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for Int4Array", src) + } + if elementsLength == 0 { + *dst = Int4Array{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int4Array", src) + } + + *dst = Int4Array{ + Elements: make([]Int4, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Int4, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to Int4Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *Int4Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to Int4Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in Int4Array", err) + } + index++ + + return index, nil +} + +func (dst Int4Array) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int4Array) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]int16: + *v = make([]int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int16: + *v = make([]*int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint16: + *v = make([]uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint16: + *v = make([]*uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int32: + *v = make([]int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int32: + *v = make([]*int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint32: + *v = make([]uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint32: + *v = make([]*uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int64: + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int64: + *v = make([]*int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint64: + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint64: + *v = make([]*uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int: + *v = make([]int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int: + *v = make([]*int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint: + *v = make([]uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint: + *v = make([]*uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *Int4Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from Int4Array") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from Int4Array") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4Array{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Int4 + + if len(uta.Elements) > 0 { + elements = make([]Int4, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Int4 + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Int4Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *Int4Array) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Int4Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Int4, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = Int4Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src Int4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src Int4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("int4"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "int4") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int4Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int4Array) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/int4range.go b/vendor/github.com/jackc/pgtype/int4range.go new file mode 100644 index 000000000..c7f51fa6a --- /dev/null +++ b/vendor/github.com/jackc/pgtype/int4range.go @@ -0,0 +1,267 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + + "github.com/jackc/pgio" +) + +type Int4range struct { + Lower Int4 + Upper Int4 + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Int4range) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Int4range{Status: Null} + return nil + } + + switch value := src.(type) { + case Int4range: + *dst = value + case *Int4range: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + default: + return fmt.Errorf("cannot convert %v to Int4range", src) + } + + return nil +} + +func (dst Int4range) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int4range) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Int4range) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4range{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Int4range{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Int4range) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4range{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Int4range{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Int4range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + var err error + + if src.LowerType != Unbounded { + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if src.UpperType != Unbounded { + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return buf, nil +} + +func (src Int4range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + buf = append(buf, rangeType) + + var err error + + if src.LowerType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if src.UpperType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int4range) Scan(src interface{}) error { + if src == nil { + *dst = Int4range{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int4range) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/vendor/github.com/jackc/pgtype/int8.go b/vendor/github.com/jackc/pgtype/int8.go new file mode 100644 index 000000000..0e0899795 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/int8.go @@ -0,0 +1,298 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "encoding/json" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgio" +) + +type Int8 struct { + Int int64 + Status Status +} + +func (dst *Int8) Set(src interface{}) error { + if src == nil { + *dst = Int8{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case int8: + *dst = Int8{Int: int64(value), Status: Present} + case uint8: + *dst = Int8{Int: int64(value), Status: Present} + case int16: + *dst = Int8{Int: int64(value), Status: Present} + case uint16: + *dst = Int8{Int: int64(value), Status: Present} + case int32: + *dst = Int8{Int: int64(value), Status: Present} + case uint32: + *dst = Int8{Int: int64(value), Status: Present} + case int64: + *dst = Int8{Int: int64(value), Status: Present} + case uint64: + if value > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + *dst = Int8{Int: int64(value), Status: Present} + case int: + if int64(value) < math.MinInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + if int64(value) > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + *dst = Int8{Int: int64(value), Status: Present} + case uint: + if uint64(value) > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + *dst = Int8{Int: int64(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return err + } + *dst = Int8{Int: num, Status: Present} + case float32: + if value > math.MaxInt64 { + return fmt.Errorf("%f is greater than maximum value for Int8", value) + } + *dst = Int8{Int: int64(value), Status: Present} + case float64: + if value > math.MaxInt64 { + return fmt.Errorf("%f is greater than maximum value for Int8", value) + } + *dst = Int8{Int: int64(value), Status: Present} + case *int8: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *uint8: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *int16: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *uint16: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *int32: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *uint32: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *int64: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *uint64: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *int: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *uint: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *float32: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + case *float64: + if value == nil { + *dst = Int8{Status: Null} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int8", value) + } + + return nil +} + +func (dst Int8) Get() interface{} { + switch dst.Status { + case Present: + return dst.Int + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int8) AssignTo(dst interface{}) error { + return int64AssignTo(int64(src.Int), src.Status, dst) +} + +func (dst *Int8) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8{Status: Null} + return nil + } + + n, err := strconv.ParseInt(string(src), 10, 64) + if err != nil { + return err + } + + *dst = Int8{Int: n, Status: Present} + return nil +} + +func (dst *Int8) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8{Status: Null} + return nil + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + n := int64(binary.BigEndian.Uint64(src)) + + *dst = Int8{Int: n, Status: Present} + return nil +} + +func (src Int8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, strconv.FormatInt(src.Int, 10)...), nil +} + +func (src Int8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return pgio.AppendInt64(buf, src.Int), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int8) Scan(src interface{}) error { + if src == nil { + *dst = Int8{Status: Null} + return nil + } + + switch src := src.(type) { + case int64: + *dst = Int8{Int: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int8) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Int), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} + +func (src Int8) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return []byte(strconv.FormatInt(src.Int, 10)), nil + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + + return nil, errBadStatus +} + +func (dst *Int8) UnmarshalJSON(b []byte) error { + var n *int64 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *dst = Int8{Status: Null} + } else { + *dst = Int8{Int: *n, Status: Present} + } + + return nil +} diff --git a/vendor/github.com/jackc/pgtype/int8_array.go b/vendor/github.com/jackc/pgtype/int8_array.go new file mode 100644 index 000000000..e405b326d --- /dev/null +++ b/vendor/github.com/jackc/pgtype/int8_array.go @@ -0,0 +1,909 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type Int8Array struct { + Elements []Int8 + Dimensions []ArrayDimension + Status Status +} + +func (dst *Int8Array) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Int8Array{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []int16: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int16: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint16: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint16: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int32: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int32: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint32: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint32: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int64: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int64: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint64: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint64: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Int8: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + *dst = Int8Array{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = Int8Array{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for Int8Array", src) + } + if elementsLength == 0 { + *dst = Int8Array{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int8Array", src) + } + + *dst = Int8Array{ + Elements: make([]Int8, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Int8, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to Int8Array, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *Int8Array) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to Int8Array") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in Int8Array", err) + } + index++ + + return index, nil +} + +func (dst Int8Array) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int8Array) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]int16: + *v = make([]int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int16: + *v = make([]*int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint16: + *v = make([]uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint16: + *v = make([]*uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int32: + *v = make([]int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int32: + *v = make([]*int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint32: + *v = make([]uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint32: + *v = make([]*uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int64: + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int64: + *v = make([]*int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint64: + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint64: + *v = make([]*uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int: + *v = make([]int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int: + *v = make([]*int, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint: + *v = make([]uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint: + *v = make([]*uint, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *Int8Array) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from Int8Array") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from Int8Array") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8Array{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Int8 + + if len(uta.Elements) > 0 { + elements = make([]Int8, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Int8 + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Int8Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *Int8Array) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Int8Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Int8, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = Int8Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src Int8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src Int8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("int8"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "int8") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int8Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int8Array) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/int8range.go b/vendor/github.com/jackc/pgtype/int8range.go new file mode 100644 index 000000000..71369373f --- /dev/null +++ b/vendor/github.com/jackc/pgtype/int8range.go @@ -0,0 +1,267 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + + "github.com/jackc/pgio" +) + +type Int8range struct { + Lower Int8 + Upper Int8 + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Int8range) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Int8range{Status: Null} + return nil + } + + switch value := src.(type) { + case Int8range: + *dst = value + case *Int8range: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + default: + return fmt.Errorf("cannot convert %v to Int8range", src) + } + + return nil +} + +func (dst Int8range) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int8range) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Int8range) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8range{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Int8range{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Int8range) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8range{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Int8range{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Int8range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + var err error + + if src.LowerType != Unbounded { + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if src.UpperType != Unbounded { + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return buf, nil +} + +func (src Int8range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + buf = append(buf, rangeType) + + var err error + + if src.LowerType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if src.UpperType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int8range) Scan(src interface{}) error { + if src == nil { + *dst = Int8range{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int8range) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/vendor/github.com/jackc/pgtype/interval.go b/vendor/github.com/jackc/pgtype/interval.go new file mode 100644 index 000000000..b01fbb7cb --- /dev/null +++ b/vendor/github.com/jackc/pgtype/interval.go @@ -0,0 +1,257 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "strconv" + "strings" + "time" + + "github.com/jackc/pgio" +) + +const ( + microsecondsPerSecond = 1000000 + microsecondsPerMinute = 60 * microsecondsPerSecond + microsecondsPerHour = 60 * microsecondsPerMinute + microsecondsPerDay = 24 * microsecondsPerHour + microsecondsPerMonth = 30 * microsecondsPerDay +) + +type Interval struct { + Microseconds int64 + Days int32 + Months int32 + Status Status +} + +func (dst *Interval) Set(src interface{}) error { + if src == nil { + *dst = Interval{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case time.Duration: + *dst = Interval{Microseconds: int64(value) / 1000, Status: Present} + default: + if originalSrc, ok := underlyingPtrType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Interval", value) + } + + return nil +} + +func (dst Interval) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Interval) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *time.Duration: + us := int64(src.Months)*microsecondsPerMonth + int64(src.Days)*microsecondsPerDay + src.Microseconds + *v = time.Duration(us) * time.Microsecond + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Interval{Status: Null} + return nil + } + + var microseconds int64 + var days int32 + var months int32 + + parts := strings.Split(string(src), " ") + + for i := 0; i < len(parts)-1; i += 2 { + scalar, err := strconv.ParseInt(parts[i], 10, 64) + if err != nil { + return fmt.Errorf("bad interval format") + } + + switch parts[i+1] { + case "year", "years": + months += int32(scalar * 12) + case "mon", "mons": + months += int32(scalar) + case "day", "days": + days = int32(scalar) + } + } + + if len(parts)%2 == 1 { + timeParts := strings.SplitN(parts[len(parts)-1], ":", 3) + if len(timeParts) != 3 { + return fmt.Errorf("bad interval format") + } + + var negative bool + if timeParts[0][0] == '-' { + negative = true + timeParts[0] = timeParts[0][1:] + } + + hours, err := strconv.ParseInt(timeParts[0], 10, 64) + if err != nil { + return fmt.Errorf("bad interval hour format: %s", timeParts[0]) + } + + minutes, err := strconv.ParseInt(timeParts[1], 10, 64) + if err != nil { + return fmt.Errorf("bad interval minute format: %s", timeParts[1]) + } + + secondParts := strings.SplitN(timeParts[2], ".", 2) + + seconds, err := strconv.ParseInt(secondParts[0], 10, 64) + if err != nil { + return fmt.Errorf("bad interval second format: %s", secondParts[0]) + } + + var uSeconds int64 + if len(secondParts) == 2 { + uSeconds, err = strconv.ParseInt(secondParts[1], 10, 64) + if err != nil { + return fmt.Errorf("bad interval decimal format: %s", secondParts[1]) + } + + for i := 0; i < 6-len(secondParts[1]); i++ { + uSeconds *= 10 + } + } + + microseconds = hours * microsecondsPerHour + microseconds += minutes * microsecondsPerMinute + microseconds += seconds * microsecondsPerSecond + microseconds += uSeconds + + if negative { + microseconds = -microseconds + } + } + + *dst = Interval{Months: months, Days: days, Microseconds: microseconds, Status: Present} + return nil +} + +func (dst *Interval) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Interval{Status: Null} + return nil + } + + if len(src) != 16 { + return fmt.Errorf("Received an invalid size for a interval: %d", len(src)) + } + + microseconds := int64(binary.BigEndian.Uint64(src)) + days := int32(binary.BigEndian.Uint32(src[8:])) + months := int32(binary.BigEndian.Uint32(src[12:])) + + *dst = Interval{Microseconds: microseconds, Days: days, Months: months, Status: Present} + return nil +} + +func (src Interval) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if src.Months != 0 { + buf = append(buf, strconv.FormatInt(int64(src.Months), 10)...) + buf = append(buf, " mon "...) + } + + if src.Days != 0 { + buf = append(buf, strconv.FormatInt(int64(src.Days), 10)...) + buf = append(buf, " day "...) + } + + absMicroseconds := src.Microseconds + if absMicroseconds < 0 { + absMicroseconds = -absMicroseconds + buf = append(buf, '-') + } + + hours := absMicroseconds / microsecondsPerHour + minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute + seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond + microseconds := absMicroseconds % microsecondsPerSecond + + timeStr := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, microseconds) + return append(buf, timeStr...), nil +} + +// EncodeBinary encodes src into w. +func (src Interval) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendInt64(buf, src.Microseconds) + buf = pgio.AppendInt32(buf, src.Days) + return pgio.AppendInt32(buf, src.Months), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Interval) Scan(src interface{}) error { + if src == nil { + *dst = Interval{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Interval) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/vendor/github.com/jackc/pgtype/json.go b/vendor/github.com/jackc/pgtype/json.go new file mode 100644 index 000000000..32bef5e76 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/json.go @@ -0,0 +1,205 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/json" + "errors" + "fmt" +) + +type JSON struct { + Bytes []byte + Status Status +} + +func (dst *JSON) Set(src interface{}) error { + if src == nil { + *dst = JSON{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case string: + *dst = JSON{Bytes: []byte(value), Status: Present} + case *string: + if value == nil { + *dst = JSON{Status: Null} + } else { + *dst = JSON{Bytes: []byte(*value), Status: Present} + } + case []byte: + if value == nil { + *dst = JSON{Status: Null} + } else { + *dst = JSON{Bytes: value, Status: Present} + } + // Encode* methods are defined on *JSON. If JSON is passed directly then the + // struct itself would be encoded instead of Bytes. This is clearly a footgun + // so detect and return an error. See https://github.com/jackc/pgx/issues/350. + case JSON: + return errors.New("use pointer to pgtype.JSON instead of value") + // Same as above but for JSONB (because they share implementation) + case JSONB: + return errors.New("use pointer to pgtype.JSONB instead of value") + + default: + buf, err := json.Marshal(value) + if err != nil { + return err + } + *dst = JSON{Bytes: buf, Status: Present} + } + + return nil +} + +func (dst JSON) Get() interface{} { + switch dst.Status { + case Present: + var i interface{} + err := json.Unmarshal(dst.Bytes, &i) + if err != nil { + return dst + } + return i + case Null: + return nil + default: + return dst.Status + } +} + +func (src *JSON) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *string: + if src.Status == Present { + *v = string(src.Bytes) + } else { + return fmt.Errorf("cannot assign non-present status to %T", dst) + } + case **string: + if src.Status == Present { + s := string(src.Bytes) + *v = &s + return nil + } else { + *v = nil + return nil + } + case *[]byte: + if src.Status != Present { + *v = nil + } else { + buf := make([]byte, len(src.Bytes)) + copy(buf, src.Bytes) + *v = buf + } + default: + data := src.Bytes + if data == nil || src.Status != Present { + data = []byte("null") + } + + return json.Unmarshal(data, dst) + } + + return nil +} + +func (JSON) PreferredResultFormat() int16 { + return TextFormatCode +} + +func (dst *JSON) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = JSON{Status: Null} + return nil + } + + *dst = JSON{Bytes: src, Status: Present} + return nil +} + +func (dst *JSON) DecodeBinary(ci *ConnInfo, src []byte) error { + return dst.DecodeText(ci, src) +} + +func (JSON) PreferredParamFormat() int16 { + return TextFormatCode +} + +func (src JSON) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, src.Bytes...), nil +} + +func (src JSON) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return src.EncodeText(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *JSON) Scan(src interface{}) error { + if src == nil { + *dst = JSON{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src JSON) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Bytes, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} + +func (src JSON) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return src.Bytes, nil + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + + return nil, errBadStatus +} + +func (dst *JSON) UnmarshalJSON(b []byte) error { + if b == nil || string(b) == "null" { + *dst = JSON{Status: Null} + } else { + *dst = JSON{Bytes: b, Status: Present} + } + return nil + +} diff --git a/vendor/github.com/jackc/pgtype/jsonb.go b/vendor/github.com/jackc/pgtype/jsonb.go new file mode 100644 index 000000000..c9dafc939 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/jsonb.go @@ -0,0 +1,85 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" +) + +type JSONB JSON + +func (dst *JSONB) Set(src interface{}) error { + return (*JSON)(dst).Set(src) +} + +func (dst JSONB) Get() interface{} { + return (JSON)(dst).Get() +} + +func (src *JSONB) AssignTo(dst interface{}) error { + return (*JSON)(src).AssignTo(dst) +} + +func (JSONB) PreferredResultFormat() int16 { + return TextFormatCode +} + +func (dst *JSONB) DecodeText(ci *ConnInfo, src []byte) error { + return (*JSON)(dst).DecodeText(ci, src) +} + +func (dst *JSONB) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = JSONB{Status: Null} + return nil + } + + if len(src) == 0 { + return fmt.Errorf("jsonb too short") + } + + if src[0] != 1 { + return fmt.Errorf("unknown jsonb version number %d", src[0]) + } + + *dst = JSONB{Bytes: src[1:], Status: Present} + return nil + +} + +func (JSONB) PreferredParamFormat() int16 { + return TextFormatCode +} + +func (src JSONB) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (JSON)(src).EncodeText(ci, buf) +} + +func (src JSONB) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, 1) + return append(buf, src.Bytes...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *JSONB) Scan(src interface{}) error { + return (*JSON)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src JSONB) Value() (driver.Value, error) { + return (JSON)(src).Value() +} + +func (src JSONB) MarshalJSON() ([]byte, error) { + return (JSON)(src).MarshalJSON() +} + +func (dst *JSONB) UnmarshalJSON(b []byte) error { + return (*JSON)(dst).UnmarshalJSON(b) +} diff --git a/vendor/github.com/jackc/pgtype/jsonb_array.go b/vendor/github.com/jackc/pgtype/jsonb_array.go new file mode 100644 index 000000000..c4b7cd3d8 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/jsonb_array.go @@ -0,0 +1,517 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type JSONBArray struct { + Elements []JSONB + Dimensions []ArrayDimension + Status Status +} + +func (dst *JSONBArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = JSONBArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []string: + if value == nil { + *dst = JSONBArray{Status: Null} + } else if len(value) == 0 { + *dst = JSONBArray{Status: Present} + } else { + elements := make([]JSONB, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = JSONBArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case [][]byte: + if value == nil { + *dst = JSONBArray{Status: Null} + } else if len(value) == 0 { + *dst = JSONBArray{Status: Present} + } else { + elements := make([]JSONB, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = JSONBArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []JSONB: + if value == nil { + *dst = JSONBArray{Status: Null} + } else if len(value) == 0 { + *dst = JSONBArray{Status: Present} + } else { + *dst = JSONBArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = JSONBArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for JSONBArray", src) + } + if elementsLength == 0 { + *dst = JSONBArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to JSONBArray", src) + } + + *dst = JSONBArray{ + Elements: make([]JSONB, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]JSONB, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to JSONBArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *JSONBArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to JSONBArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in JSONBArray", err) + } + index++ + + return index, nil +} + +func (dst JSONBArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *JSONBArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[][]byte: + *v = make([][]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *JSONBArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from JSONBArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from JSONBArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *JSONBArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = JSONBArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []JSONB + + if len(uta.Elements) > 0 { + elements = make([]JSONB, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem JSONB + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = JSONBArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *JSONBArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = JSONBArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = JSONBArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]JSONB, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = JSONBArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src JSONBArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src JSONBArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("jsonb"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "jsonb") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *JSONBArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src JSONBArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/line.go b/vendor/github.com/jackc/pgtype/line.go new file mode 100644 index 000000000..3564b1748 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/line.go @@ -0,0 +1,148 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgio" +) + +type Line struct { + A, B, C float64 + Status Status +} + +func (dst *Line) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Line", src) +} + +func (dst Line) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Line) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Line) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Line{Status: Null} + return nil + } + + if len(src) < 7 { + return fmt.Errorf("invalid length for Line: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 3) + if len(parts) < 3 { + return fmt.Errorf("invalid format for line") + } + + a, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return err + } + + b, err := strconv.ParseFloat(parts[1], 64) + if err != nil { + return err + } + + c, err := strconv.ParseFloat(parts[2], 64) + if err != nil { + return err + } + + *dst = Line{A: a, B: b, C: c, Status: Present} + return nil +} + +func (dst *Line) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Line{Status: Null} + return nil + } + + if len(src) != 24 { + return fmt.Errorf("invalid length for Line: %v", len(src)) + } + + a := binary.BigEndian.Uint64(src) + b := binary.BigEndian.Uint64(src[8:]) + c := binary.BigEndian.Uint64(src[16:]) + + *dst = Line{ + A: math.Float64frombits(a), + B: math.Float64frombits(b), + C: math.Float64frombits(c), + Status: Present, + } + return nil +} + +func (src Line) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, fmt.Sprintf(`{%s,%s,%s}`, + strconv.FormatFloat(src.A, 'f', -1, 64), + strconv.FormatFloat(src.B, 'f', -1, 64), + strconv.FormatFloat(src.C, 'f', -1, 64), + )...) + + return buf, nil +} + +func (src Line) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendUint64(buf, math.Float64bits(src.A)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.B)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.C)) + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Line) Scan(src interface{}) error { + if src == nil { + *dst = Line{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Line) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/vendor/github.com/jackc/pgtype/lseg.go b/vendor/github.com/jackc/pgtype/lseg.go new file mode 100644 index 000000000..5c4babb69 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/lseg.go @@ -0,0 +1,165 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgio" +) + +type Lseg struct { + P [2]Vec2 + Status Status +} + +func (dst *Lseg) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Lseg", src) +} + +func (dst Lseg) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Lseg) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Lseg) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Lseg{Status: Null} + return nil + } + + if len(src) < 11 { + return fmt.Errorf("invalid length for Lseg: %v", len(src)) + } + + str := string(src[2:]) + + var end int + end = strings.IndexByte(str, ',') + + x1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+3:] + end = strings.IndexByte(str, ',') + + x2, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1 : len(str)-2] + + y2, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } + + *dst = Lseg{P: [2]Vec2{{x1, y1}, {x2, y2}}, Status: Present} + return nil +} + +func (dst *Lseg) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Lseg{Status: Null} + return nil + } + + if len(src) != 32 { + return fmt.Errorf("invalid length for Lseg: %v", len(src)) + } + + x1 := binary.BigEndian.Uint64(src) + y1 := binary.BigEndian.Uint64(src[8:]) + x2 := binary.BigEndian.Uint64(src[16:]) + y2 := binary.BigEndian.Uint64(src[24:]) + + *dst = Lseg{ + P: [2]Vec2{ + {math.Float64frombits(x1), math.Float64frombits(y1)}, + {math.Float64frombits(x2), math.Float64frombits(y2)}, + }, + Status: Present, + } + return nil +} + +func (src Lseg) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, + strconv.FormatFloat(src.P[0].X, 'f', -1, 64), + strconv.FormatFloat(src.P[0].Y, 'f', -1, 64), + strconv.FormatFloat(src.P[1].X, 'f', -1, 64), + strconv.FormatFloat(src.P[1].Y, 'f', -1, 64), + )...) + + return buf, nil +} + +func (src Lseg) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].Y)) + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Lseg) Scan(src interface{}) error { + if src == nil { + *dst = Lseg{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Lseg) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/vendor/github.com/jackc/pgtype/macaddr.go b/vendor/github.com/jackc/pgtype/macaddr.go new file mode 100644 index 000000000..1d3cfe7b1 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/macaddr.go @@ -0,0 +1,173 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + "net" +) + +type Macaddr struct { + Addr net.HardwareAddr + Status Status +} + +func (dst *Macaddr) Set(src interface{}) error { + if src == nil { + *dst = Macaddr{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case net.HardwareAddr: + addr := make(net.HardwareAddr, len(value)) + copy(addr, value) + *dst = Macaddr{Addr: addr, Status: Present} + case string: + addr, err := net.ParseMAC(value) + if err != nil { + return err + } + *dst = Macaddr{Addr: addr, Status: Present} + case *net.HardwareAddr: + if value == nil { + *dst = Macaddr{Status: Null} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Macaddr{Status: Null} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingPtrType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Macaddr", value) + } + + return nil +} + +func (dst Macaddr) Get() interface{} { + switch dst.Status { + case Present: + return dst.Addr + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Macaddr) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *net.HardwareAddr: + *v = make(net.HardwareAddr, len(src.Addr)) + copy(*v, src.Addr) + return nil + case *string: + *v = src.Addr.String() + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (dst *Macaddr) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Macaddr{Status: Null} + return nil + } + + addr, err := net.ParseMAC(string(src)) + if err != nil { + return err + } + + *dst = Macaddr{Addr: addr, Status: Present} + return nil +} + +func (dst *Macaddr) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Macaddr{Status: Null} + return nil + } + + if len(src) != 6 { + return fmt.Errorf("Received an invalid size for a macaddr: %d", len(src)) + } + + addr := make(net.HardwareAddr, 6) + copy(addr, src) + + *dst = Macaddr{Addr: addr, Status: Present} + + return nil +} + +func (src Macaddr) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, src.Addr.String()...), nil +} + +// EncodeBinary encodes src into w. +func (src Macaddr) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, src.Addr...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Macaddr) Scan(src interface{}) error { + if src == nil { + *dst = Macaddr{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Macaddr) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/vendor/github.com/jackc/pgtype/macaddr_array.go b/vendor/github.com/jackc/pgtype/macaddr_array.go new file mode 100644 index 000000000..bdb1f2034 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/macaddr_array.go @@ -0,0 +1,518 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "net" + "reflect" + + "github.com/jackc/pgio" +) + +type MacaddrArray struct { + Elements []Macaddr + Dimensions []ArrayDimension + Status Status +} + +func (dst *MacaddrArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = MacaddrArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []net.HardwareAddr: + if value == nil { + *dst = MacaddrArray{Status: Null} + } else if len(value) == 0 { + *dst = MacaddrArray{Status: Present} + } else { + elements := make([]Macaddr, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = MacaddrArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*net.HardwareAddr: + if value == nil { + *dst = MacaddrArray{Status: Null} + } else if len(value) == 0 { + *dst = MacaddrArray{Status: Present} + } else { + elements := make([]Macaddr, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = MacaddrArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Macaddr: + if value == nil { + *dst = MacaddrArray{Status: Null} + } else if len(value) == 0 { + *dst = MacaddrArray{Status: Present} + } else { + *dst = MacaddrArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = MacaddrArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for MacaddrArray", src) + } + if elementsLength == 0 { + *dst = MacaddrArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to MacaddrArray", src) + } + + *dst = MacaddrArray{ + Elements: make([]Macaddr, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Macaddr, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to MacaddrArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *MacaddrArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to MacaddrArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in MacaddrArray", err) + } + index++ + + return index, nil +} + +func (dst MacaddrArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *MacaddrArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]net.HardwareAddr: + *v = make([]net.HardwareAddr, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*net.HardwareAddr: + *v = make([]*net.HardwareAddr, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *MacaddrArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from MacaddrArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from MacaddrArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *MacaddrArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = MacaddrArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Macaddr + + if len(uta.Elements) > 0 { + elements = make([]Macaddr, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Macaddr + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = MacaddrArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *MacaddrArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = MacaddrArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = MacaddrArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Macaddr, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = MacaddrArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src MacaddrArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src MacaddrArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("macaddr"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "macaddr") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *MacaddrArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src MacaddrArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/name.go b/vendor/github.com/jackc/pgtype/name.go new file mode 100644 index 000000000..7ce8d25e9 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/name.go @@ -0,0 +1,58 @@ +package pgtype + +import ( + "database/sql/driver" +) + +// Name is a type used for PostgreSQL's special 63-byte +// name data type, used for identifiers like table names. +// The pg_class.relname column is a good example of where the +// name data type is used. +// +// Note that the underlying Go data type of pgx.Name is string, +// so there is no way to enforce the 63-byte length. Inputting +// a longer name into PostgreSQL will result in silent truncation +// to 63 bytes. +// +// Also, if you have custom-compiled PostgreSQL and set +// NAMEDATALEN to a different value, obviously that number of +// bytes applies, rather than the default 63. +type Name Text + +func (dst *Name) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst Name) Get() interface{} { + return (Text)(dst).Get() +} + +func (src *Name) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *Name) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *Name) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} + +func (src Name) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Text)(src).EncodeText(ci, buf) +} + +func (src Name) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Text)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Name) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Name) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/vendor/github.com/jackc/pgtype/numeric.go b/vendor/github.com/jackc/pgtype/numeric.go new file mode 100644 index 000000000..a7efa704c --- /dev/null +++ b/vendor/github.com/jackc/pgtype/numeric.go @@ -0,0 +1,716 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "math/big" + "strconv" + "strings" + + "github.com/jackc/pgio" +) + +// PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000 +const nbase = 10000 + +const ( + pgNumericNaN = 0x00000000c0000000 + pgNumericNaNSign = 0xc000 +) + +var big0 *big.Int = big.NewInt(0) +var big1 *big.Int = big.NewInt(1) +var big10 *big.Int = big.NewInt(10) +var big100 *big.Int = big.NewInt(100) +var big1000 *big.Int = big.NewInt(1000) + +var bigMaxInt8 *big.Int = big.NewInt(math.MaxInt8) +var bigMinInt8 *big.Int = big.NewInt(math.MinInt8) +var bigMaxInt16 *big.Int = big.NewInt(math.MaxInt16) +var bigMinInt16 *big.Int = big.NewInt(math.MinInt16) +var bigMaxInt32 *big.Int = big.NewInt(math.MaxInt32) +var bigMinInt32 *big.Int = big.NewInt(math.MinInt32) +var bigMaxInt64 *big.Int = big.NewInt(math.MaxInt64) +var bigMinInt64 *big.Int = big.NewInt(math.MinInt64) +var bigMaxInt *big.Int = big.NewInt(int64(maxInt)) +var bigMinInt *big.Int = big.NewInt(int64(minInt)) + +var bigMaxUint8 *big.Int = big.NewInt(math.MaxUint8) +var bigMaxUint16 *big.Int = big.NewInt(math.MaxUint16) +var bigMaxUint32 *big.Int = big.NewInt(math.MaxUint32) +var bigMaxUint64 *big.Int = (&big.Int{}).SetUint64(uint64(math.MaxUint64)) +var bigMaxUint *big.Int = (&big.Int{}).SetUint64(uint64(maxUint)) + +var bigNBase *big.Int = big.NewInt(nbase) +var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) +var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) +var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) + +type Numeric struct { + Int *big.Int + Exp int32 + Status Status + NaN bool +} + +func (dst *Numeric) Set(src interface{}) error { + if src == nil { + *dst = Numeric{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case float32: + if math.IsNaN(float64(value)) { + *dst = Numeric{Status: Present, NaN: true} + return nil + } + num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64)) + if err != nil { + return err + } + *dst = Numeric{Int: num, Exp: exp, Status: Present} + case float64: + if math.IsNaN(value) { + *dst = Numeric{Status: Present, NaN: true} + return nil + } + num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64)) + if err != nil { + return err + } + *dst = Numeric{Int: num, Exp: exp, Status: Present} + case int8: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case uint8: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case int16: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case uint16: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case int32: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case uint32: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case int64: + *dst = Numeric{Int: big.NewInt(value), Status: Present} + case uint64: + *dst = Numeric{Int: (&big.Int{}).SetUint64(value), Status: Present} + case int: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case uint: + *dst = Numeric{Int: (&big.Int{}).SetUint64(uint64(value)), Status: Present} + case string: + num, exp, err := parseNumericString(value) + if err != nil { + return err + } + *dst = Numeric{Int: num, Exp: exp, Status: Present} + case *float64: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *float32: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *int8: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *uint8: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *int16: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *uint16: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *int32: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *uint32: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *int64: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *uint64: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *int: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *uint: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + case *string: + if value == nil { + *dst = Numeric{Status: Null} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Numeric", value) + } + + return nil +} + +func (dst Numeric) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Numeric) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *float32: + f, err := src.toFloat64() + if err != nil { + return err + } + return float64AssignTo(f, src.Status, dst) + case *float64: + f, err := src.toFloat64() + if err != nil { + return err + } + return float64AssignTo(f, src.Status, dst) + case *int: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int(normalizedInt.Int64()) + case *int8: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt8) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt8) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int8(normalizedInt.Int64()) + case *int16: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt16) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt16) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int16(normalizedInt.Int64()) + case *int32: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt32) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt32) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int32(normalizedInt.Int64()) + case *int64: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt64) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt64) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = normalizedInt.Int64() + case *uint: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint(normalizedInt.Uint64()) + case *uint8: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint8) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint8(normalizedInt.Uint64()) + case *uint16: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint16) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint16(normalizedInt.Uint64()) + case *uint32: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint32) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint32(normalizedInt.Uint64()) + case *uint64: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint64) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = normalizedInt.Uint64() + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + + return nil +} + +func (dst *Numeric) toBigInt() (*big.Int, error) { + if dst.Exp == 0 { + return dst.Int, nil + } + + num := &big.Int{} + num.Set(dst.Int) + if dst.Exp > 0 { + mul := &big.Int{} + mul.Exp(big10, big.NewInt(int64(dst.Exp)), nil) + num.Mul(num, mul) + return num, nil + } + + div := &big.Int{} + div.Exp(big10, big.NewInt(int64(-dst.Exp)), nil) + remainder := &big.Int{} + num.DivMod(num, div, remainder) + if remainder.Cmp(big0) != 0 { + return nil, fmt.Errorf("cannot convert %v to integer", dst) + } + return num, nil +} + +func (src *Numeric) toFloat64() (float64, error) { + if src.NaN { + return math.NaN(), nil + } + + buf := make([]byte, 0, 32) + + buf = append(buf, src.Int.String()...) + buf = append(buf, 'e') + buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...) + + f, err := strconv.ParseFloat(string(buf), 64) + if err != nil { + return 0, err + } + return f, nil +} + +func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numeric{Status: Null} + return nil + } + + if string(src) == "NaN" { + *dst = Numeric{Status: Present, NaN: true} + return nil + } + + num, exp, err := parseNumericString(string(src)) + if err != nil { + return err + } + + *dst = Numeric{Int: num, Exp: exp, Status: Present} + return nil +} + +func parseNumericString(str string) (n *big.Int, exp int32, err error) { + parts := strings.SplitN(str, ".", 2) + digits := strings.Join(parts, "") + + if len(parts) > 1 { + exp = int32(-len(parts[1])) + } else { + for len(digits) > 1 && digits[len(digits)-1] == '0' && digits[len(digits)-2] != '-' { + digits = digits[:len(digits)-1] + exp++ + } + } + + accum := &big.Int{} + if _, ok := accum.SetString(digits, 10); !ok { + return nil, 0, fmt.Errorf("%s is not a number", str) + } + + return accum, exp, nil +} + +func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numeric{Status: Null} + return nil + } + + if len(src) < 8 { + return fmt.Errorf("numeric incomplete %v", src) + } + + rp := 0 + ndigits := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + weight := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + sign := uint16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + dscale := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + if sign == pgNumericNaNSign { + *dst = Numeric{Status: Present, NaN: true} + return nil + } + + if ndigits == 0 { + *dst = Numeric{Int: big.NewInt(0), Status: Present} + return nil + } + + if len(src[rp:]) < int(ndigits)*2 { + return fmt.Errorf("numeric incomplete %v", src) + } + + accum := &big.Int{} + + for i := 0; i < int(ndigits+3)/4; i++ { + int64accum, bytesRead, digitsRead := nbaseDigitsToInt64(src[rp:]) + rp += bytesRead + + if i > 0 { + var mul *big.Int + switch digitsRead { + case 1: + mul = bigNBase + case 2: + mul = bigNBaseX2 + case 3: + mul = bigNBaseX3 + case 4: + mul = bigNBaseX4 + default: + return fmt.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead) + } + accum.Mul(accum, mul) + } + + accum.Add(accum, big.NewInt(int64accum)) + } + + exp := (int32(weight) - int32(ndigits) + 1) * 4 + + if dscale > 0 { + fracNBaseDigits := ndigits - weight - 1 + fracDecimalDigits := fracNBaseDigits * 4 + + if dscale > fracDecimalDigits { + multCount := int(dscale - fracDecimalDigits) + for i := 0; i < multCount; i++ { + accum.Mul(accum, big10) + exp-- + } + } else if dscale < fracDecimalDigits { + divCount := int(fracDecimalDigits - dscale) + for i := 0; i < divCount; i++ { + accum.Div(accum, big10) + exp++ + } + } + } + + reduced := &big.Int{} + remainder := &big.Int{} + if exp >= 0 { + for { + reduced.DivMod(accum, big10, remainder) + if remainder.Cmp(big0) != 0 { + break + } + accum.Set(reduced) + exp++ + } + } + + if sign != 0 { + accum.Neg(accum) + } + + *dst = Numeric{Int: accum, Exp: exp, Status: Present} + + return nil + +} + +func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { + digits := len(src) / 2 + if digits > 4 { + digits = 4 + } + + rp := 0 + + for i := 0; i < digits; i++ { + if i > 0 { + accum *= nbase + } + accum += int64(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } + + return accum, rp, digits +} + +func (src Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if src.NaN { + buf = append(buf, "NaN"...) + return buf, nil + } + + buf = append(buf, src.Int.String()...) + buf = append(buf, 'e') + buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...) + return buf, nil +} + +func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if src.NaN { + buf = pgio.AppendUint64(buf, pgNumericNaN) + return buf, nil + } + + var sign int16 + if src.Int.Cmp(big0) < 0 { + sign = 16384 + } + + absInt := &big.Int{} + wholePart := &big.Int{} + fracPart := &big.Int{} + remainder := &big.Int{} + absInt.Abs(src.Int) + + // Normalize absInt and exp to where exp is always a multiple of 4. This makes + // converting to 16-bit base 10,000 digits easier. + var exp int32 + switch src.Exp % 4 { + case 1, -3: + exp = src.Exp - 1 + absInt.Mul(absInt, big10) + case 2, -2: + exp = src.Exp - 2 + absInt.Mul(absInt, big100) + case 3, -1: + exp = src.Exp - 3 + absInt.Mul(absInt, big1000) + default: + exp = src.Exp + } + + if exp < 0 { + divisor := &big.Int{} + divisor.Exp(big10, big.NewInt(int64(-exp)), nil) + wholePart.DivMod(absInt, divisor, fracPart) + fracPart.Add(fracPart, divisor) + } else { + wholePart = absInt + } + + var wholeDigits, fracDigits []int16 + + for wholePart.Cmp(big0) != 0 { + wholePart.DivMod(wholePart, bigNBase, remainder) + wholeDigits = append(wholeDigits, int16(remainder.Int64())) + } + + if fracPart.Cmp(big0) != 0 { + for fracPart.Cmp(big1) != 0 { + fracPart.DivMod(fracPart, bigNBase, remainder) + fracDigits = append(fracDigits, int16(remainder.Int64())) + } + } + + buf = pgio.AppendInt16(buf, int16(len(wholeDigits)+len(fracDigits))) + + var weight int16 + if len(wholeDigits) > 0 { + weight = int16(len(wholeDigits) - 1) + if exp > 0 { + weight += int16(exp / 4) + } + } else { + weight = int16(exp/4) - 1 + int16(len(fracDigits)) + } + buf = pgio.AppendInt16(buf, weight) + + buf = pgio.AppendInt16(buf, sign) + + var dscale int16 + if src.Exp < 0 { + dscale = int16(-src.Exp) + } + buf = pgio.AppendInt16(buf, dscale) + + for i := len(wholeDigits) - 1; i >= 0; i-- { + buf = pgio.AppendInt16(buf, wholeDigits[i]) + } + + for i := len(fracDigits) - 1; i >= 0; i-- { + buf = pgio.AppendInt16(buf, fracDigits[i]) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Numeric) Scan(src interface{}) error { + if src == nil { + *dst = Numeric{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Numeric) Value() (driver.Value, error) { + switch src.Status { + case Present: + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + + return string(buf), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/vendor/github.com/jackc/pgtype/numeric_array.go b/vendor/github.com/jackc/pgtype/numeric_array.go new file mode 100644 index 000000000..31899dec9 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/numeric_array.go @@ -0,0 +1,685 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type NumericArray struct { + Elements []Numeric + Dimensions []ArrayDimension + Status Status +} + +func (dst *NumericArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = NumericArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []float32: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*float32: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []float64: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*float64: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []int64: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*int64: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint64: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*uint64: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Numeric: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + *dst = NumericArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = NumericArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for NumericArray", src) + } + if elementsLength == 0 { + *dst = NumericArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to NumericArray", src) + } + + *dst = NumericArray{ + Elements: make([]Numeric, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Numeric, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to NumericArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *NumericArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to NumericArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in NumericArray", err) + } + index++ + + return index, nil +} + +func (dst NumericArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *NumericArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]float32: + *v = make([]float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*float32: + *v = make([]*float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]float64: + *v = make([]float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*float64: + *v = make([]*float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]int64: + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*int64: + *v = make([]*int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]uint64: + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*uint64: + *v = make([]*uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *NumericArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from NumericArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from NumericArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *NumericArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = NumericArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Numeric + + if len(uta.Elements) > 0 { + elements = make([]Numeric, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Numeric + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = NumericArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *NumericArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = NumericArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = NumericArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Numeric, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = NumericArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src NumericArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src NumericArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("numeric"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "numeric") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *NumericArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src NumericArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/numrange.go b/vendor/github.com/jackc/pgtype/numrange.go new file mode 100644 index 000000000..3d5951a24 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/numrange.go @@ -0,0 +1,267 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + + "github.com/jackc/pgio" +) + +type Numrange struct { + Lower Numeric + Upper Numeric + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Numrange) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Numrange{Status: Null} + return nil + } + + switch value := src.(type) { + case Numrange: + *dst = value + case *Numrange: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + default: + return fmt.Errorf("cannot convert %v to Numrange", src) + } + + return nil +} + +func (dst Numrange) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Numrange) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Numrange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numrange{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Numrange{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Numrange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numrange{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Numrange{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + var err error + + if src.LowerType != Unbounded { + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if src.UpperType != Unbounded { + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return buf, nil +} + +func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + buf = append(buf, rangeType) + + var err error + + if src.LowerType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if src.UpperType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Numrange) Scan(src interface{}) error { + if src == nil { + *dst = Numrange{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Numrange) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/vendor/github.com/jackc/pgtype/oid.go b/vendor/github.com/jackc/pgtype/oid.go new file mode 100644 index 000000000..31677e894 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/oid.go @@ -0,0 +1,81 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "strconv" + + "github.com/jackc/pgio" +) + +// OID (Object Identifier Type) is, according to +// https://www.postgresql.org/docs/current/static/datatype-oid.html, used +// internally by PostgreSQL as a primary key for various system tables. It is +// currently implemented as an unsigned four-byte integer. Its definition can be +// found in src/include/postgres_ext.h in the PostgreSQL sources. Because it is +// so frequently required to be in a NOT NULL condition OID cannot be NULL. To +// allow for NULL OIDs use OIDValue. +type OID uint32 + +func (dst *OID) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + return fmt.Errorf("cannot decode nil into OID") + } + + n, err := strconv.ParseUint(string(src), 10, 32) + if err != nil { + return err + } + + *dst = OID(n) + return nil +} + +func (dst *OID) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + return fmt.Errorf("cannot decode nil into OID") + } + + if len(src) != 4 { + return fmt.Errorf("invalid length: %v", len(src)) + } + + n := binary.BigEndian.Uint32(src) + *dst = OID(n) + return nil +} + +func (src OID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return append(buf, strconv.FormatUint(uint64(src), 10)...), nil +} + +func (src OID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return pgio.AppendUint32(buf, uint32(src)), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *OID) Scan(src interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", src) + } + + switch src := src.(type) { + case int64: + *dst = OID(src) + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src OID) Value() (driver.Value, error) { + return int64(src), nil +} diff --git a/vendor/github.com/jackc/pgtype/oid_value.go b/vendor/github.com/jackc/pgtype/oid_value.go new file mode 100644 index 000000000..5dc9136cb --- /dev/null +++ b/vendor/github.com/jackc/pgtype/oid_value.go @@ -0,0 +1,55 @@ +package pgtype + +import ( + "database/sql/driver" +) + +// OIDValue (Object Identifier Type) is, according to +// https://www.postgresql.org/docs/current/static/datatype-OIDValue.html, used +// internally by PostgreSQL as a primary key for various system tables. It is +// currently implemented as an unsigned four-byte integer. Its definition can be +// found in src/include/postgres_ext.h in the PostgreSQL sources. +type OIDValue pguint32 + +// Set converts from src to dst. Note that as OIDValue is not a general +// number type Set does not do automatic type conversion as other number +// types do. +func (dst *OIDValue) Set(src interface{}) error { + return (*pguint32)(dst).Set(src) +} + +func (dst OIDValue) Get() interface{} { + return (pguint32)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as OIDValue is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *OIDValue) AssignTo(dst interface{}) error { + return (*pguint32)(src).AssignTo(dst) +} + +func (dst *OIDValue) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) +} + +func (dst *OIDValue) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) +} + +func (src OIDValue) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (pguint32)(src).EncodeText(ci, buf) +} + +func (src OIDValue) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (pguint32)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *OIDValue) Scan(src interface{}) error { + return (*pguint32)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src OIDValue) Value() (driver.Value, error) { + return (pguint32)(src).Value() +} diff --git a/vendor/github.com/jackc/pgtype/path.go b/vendor/github.com/jackc/pgtype/path.go new file mode 100644 index 000000000..9f89969e0 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/path.go @@ -0,0 +1,195 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgio" +) + +type Path struct { + P []Vec2 + Closed bool + Status Status +} + +func (dst *Path) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Path", src) +} + +func (dst Path) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Path) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Path) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Path{Status: Null} + return nil + } + + if len(src) < 7 { + return fmt.Errorf("invalid length for Path: %v", len(src)) + } + + closed := src[0] == '(' + points := make([]Vec2, 0) + + str := string(src[2:]) + + for { + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + points = append(points, Vec2{x, y}) + + if end+3 < len(str) { + str = str[end+3:] + } else { + break + } + } + + *dst = Path{P: points, Closed: closed, Status: Present} + return nil +} + +func (dst *Path) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Path{Status: Null} + return nil + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for Path: %v", len(src)) + } + + closed := src[0] == 1 + pointCount := int(binary.BigEndian.Uint32(src[1:])) + + rp := 5 + + if 5+pointCount*16 != len(src) { + return fmt.Errorf("invalid length for Path with %d points: %v", pointCount, len(src)) + } + + points := make([]Vec2, pointCount) + for i := 0; i < len(points); i++ { + x := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + y := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + points[i] = Vec2{math.Float64frombits(x), math.Float64frombits(y)} + } + + *dst = Path{ + P: points, + Closed: closed, + Status: Present, + } + return nil +} + +func (src Path) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var startByte, endByte byte + if src.Closed { + startByte = '(' + endByte = ')' + } else { + startByte = '[' + endByte = ']' + } + buf = append(buf, startByte) + + for i, p := range src.P { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, fmt.Sprintf(`(%s,%s)`, + strconv.FormatFloat(p.X, 'f', -1, 64), + strconv.FormatFloat(p.Y, 'f', -1, 64), + )...) + } + + return append(buf, endByte), nil +} + +func (src Path) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var closeByte byte + if src.Closed { + closeByte = 1 + } + buf = append(buf, closeByte) + + buf = pgio.AppendInt32(buf, int32(len(src.P))) + + for _, p := range src.P { + buf = pgio.AppendUint64(buf, math.Float64bits(p.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(p.Y)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Path) Scan(src interface{}) error { + if src == nil { + *dst = Path{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Path) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/vendor/github.com/jackc/pgtype/pgtype.go b/vendor/github.com/jackc/pgtype/pgtype.go new file mode 100644 index 000000000..4a6808449 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/pgtype.go @@ -0,0 +1,940 @@ +package pgtype + +import ( + "database/sql" + "encoding/binary" + "errors" + "fmt" + "math" + "net" + "reflect" + "time" +) + +// PostgreSQL oids for common types +const ( + BoolOID = 16 + ByteaOID = 17 + QCharOID = 18 + NameOID = 19 + Int8OID = 20 + Int2OID = 21 + Int4OID = 23 + TextOID = 25 + OIDOID = 26 + TIDOID = 27 + XIDOID = 28 + CIDOID = 29 + JSONOID = 114 + PointOID = 600 + LsegOID = 601 + PathOID = 602 + BoxOID = 603 + PolygonOID = 604 + LineOID = 628 + CIDROID = 650 + CIDRArrayOID = 651 + Float4OID = 700 + Float8OID = 701 + CircleOID = 718 + UnknownOID = 705 + MacaddrOID = 829 + InetOID = 869 + BoolArrayOID = 1000 + Int2ArrayOID = 1005 + Int4ArrayOID = 1007 + TextArrayOID = 1009 + ByteaArrayOID = 1001 + BPCharArrayOID = 1014 + VarcharArrayOID = 1015 + Int8ArrayOID = 1016 + Float4ArrayOID = 1021 + Float8ArrayOID = 1022 + ACLItemOID = 1033 + ACLItemArrayOID = 1034 + InetArrayOID = 1041 + BPCharOID = 1042 + VarcharOID = 1043 + DateOID = 1082 + TimeOID = 1083 + TimestampOID = 1114 + TimestampArrayOID = 1115 + DateArrayOID = 1182 + TimestamptzOID = 1184 + TimestamptzArrayOID = 1185 + IntervalOID = 1186 + NumericArrayOID = 1231 + BitOID = 1560 + VarbitOID = 1562 + NumericOID = 1700 + RecordOID = 2249 + UUIDOID = 2950 + UUIDArrayOID = 2951 + JSONBOID = 3802 + JSONBArrayOID = 3807 + DaterangeOID = 3912 + Int4rangeOID = 3904 + NumrangeOID = 3906 + TsrangeOID = 3908 + TsrangeArrayOID = 3909 + TstzrangeOID = 3910 + TstzrangeArrayOID = 3911 + Int8rangeOID = 3926 +) + +type Status byte + +const ( + Undefined Status = iota + Null + Present +) + +type InfinityModifier int8 + +const ( + Infinity InfinityModifier = 1 + None InfinityModifier = 0 + NegativeInfinity InfinityModifier = -Infinity +) + +func (im InfinityModifier) String() string { + switch im { + case None: + return "none" + case Infinity: + return "infinity" + case NegativeInfinity: + return "-infinity" + default: + return "invalid" + } +} + +// PostgreSQL format codes +const ( + TextFormatCode = 0 + BinaryFormatCode = 1 +) + +// Value translates values to and from an internal canonical representation for the type. To actually be usable a type +// that implements Value should also implement some combination of BinaryDecoder, BinaryEncoder, TextDecoder, +// and TextEncoder. +// +// Operations that update a Value (e.g. Set, DecodeText, DecodeBinary) should entirely replace the value. e.g. Internal +// slices should be replaced not resized and reused. This allows Get and AssignTo to return a slice directly rather +// than incur a usually unnecessary copy. +type Value interface { + // Set converts and assigns src to itself. Value takes ownership of src. + Set(src interface{}) error + + // Get returns the simplest representation of Value. Get may return a pointer to an internal value but it must never + // mutate that value. e.g. If Get returns a []byte Value must never change the contents of the []byte. + Get() interface{} + + // AssignTo converts and assigns the Value to dst. AssignTo may a pointer to an internal value but it must never + // mutate that value. e.g. If Get returns a []byte Value must never change the contents of the []byte. + AssignTo(dst interface{}) error +} + +// TypeValue is a Value where instances can represent different PostgreSQL types. This can be useful for +// representing types such as enums, composites, and arrays. +// +// In general, instances of TypeValue should not be used to directly represent a value. It should only be used as an +// encoder and decoder internal to ConnInfo. +type TypeValue interface { + Value + + // NewTypeValue creates a TypeValue including references to internal type information. e.g. the list of members + // in an EnumType. + NewTypeValue() Value + + // TypeName returns the PostgreSQL name of this type. + TypeName() string +} + +// ValueTranscoder is a value that implements the text and binary encoding and decoding interfaces. +type ValueTranscoder interface { + Value + TextEncoder + BinaryEncoder + TextDecoder + BinaryDecoder +} + +// ResultFormatPreferrer allows a type to specify its preferred result format instead of it being inferred from +// whether it is also a BinaryDecoder. +type ResultFormatPreferrer interface { + PreferredResultFormat() int16 +} + +// ParamFormatPreferrer allows a type to specify its preferred param format instead of it being inferred from +// whether it is also a BinaryEncoder. +type ParamFormatPreferrer interface { + PreferredParamFormat() int16 +} + +type BinaryDecoder interface { + // DecodeBinary decodes src into BinaryDecoder. If src is nil then the + // original SQL value is NULL. BinaryDecoder takes ownership of src. The + // caller MUST not use it again. + DecodeBinary(ci *ConnInfo, src []byte) error +} + +type TextDecoder interface { + // DecodeText decodes src into TextDecoder. If src is nil then the original + // SQL value is NULL. TextDecoder takes ownership of src. The caller MUST not + // use it again. + DecodeText(ci *ConnInfo, src []byte) error +} + +// BinaryEncoder is implemented by types that can encode themselves into the +// PostgreSQL binary wire format. +type BinaryEncoder interface { + // EncodeBinary should append the binary format of self to buf. If self is the + // SQL value NULL then append nothing and return (nil, nil). The caller of + // EncodeBinary is responsible for writing the correct NULL value or the + // length of the data written. + EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) +} + +// TextEncoder is implemented by types that can encode themselves into the +// PostgreSQL text wire format. +type TextEncoder interface { + // EncodeText should append the text format of self to buf. If self is the + // SQL value NULL then append nothing and return (nil, nil). The caller of + // EncodeText is responsible for writing the correct NULL value or the + // length of the data written. + EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) +} + +var errUndefined = errors.New("cannot encode status undefined") +var errBadStatus = errors.New("invalid status") + +type nullAssignmentError struct { + dst interface{} +} + +func (e *nullAssignmentError) Error() string { + return fmt.Sprintf("cannot assign NULL to %T", e.dst) +} + +type DataType struct { + Value Value + + textDecoder TextDecoder + binaryDecoder BinaryDecoder + + Name string + OID uint32 +} + +type ConnInfo struct { + oidToDataType map[uint32]*DataType + nameToDataType map[string]*DataType + reflectTypeToName map[reflect.Type]string + oidToParamFormatCode map[uint32]int16 + oidToResultFormatCode map[uint32]int16 + + reflectTypeToDataType map[reflect.Type]*DataType +} + +func newConnInfo() *ConnInfo { + return &ConnInfo{ + oidToDataType: make(map[uint32]*DataType), + nameToDataType: make(map[string]*DataType), + reflectTypeToName: make(map[reflect.Type]string), + oidToParamFormatCode: make(map[uint32]int16), + oidToResultFormatCode: make(map[uint32]int16), + } +} + +func NewConnInfo() *ConnInfo { + ci := newConnInfo() + + ci.RegisterDataType(DataType{Value: &ACLItemArray{}, Name: "_aclitem", OID: ACLItemArrayOID}) + ci.RegisterDataType(DataType{Value: &BoolArray{}, Name: "_bool", OID: BoolArrayOID}) + ci.RegisterDataType(DataType{Value: &BPCharArray{}, Name: "_bpchar", OID: BPCharArrayOID}) + ci.RegisterDataType(DataType{Value: &ByteaArray{}, Name: "_bytea", OID: ByteaArrayOID}) + ci.RegisterDataType(DataType{Value: &CIDRArray{}, Name: "_cidr", OID: CIDRArrayOID}) + ci.RegisterDataType(DataType{Value: &DateArray{}, Name: "_date", OID: DateArrayOID}) + ci.RegisterDataType(DataType{Value: &Float4Array{}, Name: "_float4", OID: Float4ArrayOID}) + ci.RegisterDataType(DataType{Value: &Float8Array{}, Name: "_float8", OID: Float8ArrayOID}) + ci.RegisterDataType(DataType{Value: &InetArray{}, Name: "_inet", OID: InetArrayOID}) + ci.RegisterDataType(DataType{Value: &Int2Array{}, Name: "_int2", OID: Int2ArrayOID}) + ci.RegisterDataType(DataType{Value: &Int4Array{}, Name: "_int4", OID: Int4ArrayOID}) + ci.RegisterDataType(DataType{Value: &Int8Array{}, Name: "_int8", OID: Int8ArrayOID}) + ci.RegisterDataType(DataType{Value: &NumericArray{}, Name: "_numeric", OID: NumericArrayOID}) + ci.RegisterDataType(DataType{Value: &TextArray{}, Name: "_text", OID: TextArrayOID}) + ci.RegisterDataType(DataType{Value: &TimestampArray{}, Name: "_timestamp", OID: TimestampArrayOID}) + ci.RegisterDataType(DataType{Value: &TimestamptzArray{}, Name: "_timestamptz", OID: TimestamptzArrayOID}) + ci.RegisterDataType(DataType{Value: &UUIDArray{}, Name: "_uuid", OID: UUIDArrayOID}) + ci.RegisterDataType(DataType{Value: &VarcharArray{}, Name: "_varchar", OID: VarcharArrayOID}) + ci.RegisterDataType(DataType{Value: &ACLItem{}, Name: "aclitem", OID: ACLItemOID}) + ci.RegisterDataType(DataType{Value: &Bit{}, Name: "bit", OID: BitOID}) + ci.RegisterDataType(DataType{Value: &Bool{}, Name: "bool", OID: BoolOID}) + ci.RegisterDataType(DataType{Value: &Box{}, Name: "box", OID: BoxOID}) + ci.RegisterDataType(DataType{Value: &BPChar{}, Name: "bpchar", OID: BPCharOID}) + ci.RegisterDataType(DataType{Value: &Bytea{}, Name: "bytea", OID: ByteaOID}) + ci.RegisterDataType(DataType{Value: &QChar{}, Name: "char", OID: QCharOID}) + ci.RegisterDataType(DataType{Value: &CID{}, Name: "cid", OID: CIDOID}) + ci.RegisterDataType(DataType{Value: &CIDR{}, Name: "cidr", OID: CIDROID}) + ci.RegisterDataType(DataType{Value: &Circle{}, Name: "circle", OID: CircleOID}) + ci.RegisterDataType(DataType{Value: &Date{}, Name: "date", OID: DateOID}) + ci.RegisterDataType(DataType{Value: &Daterange{}, Name: "daterange", OID: DaterangeOID}) + ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID}) + ci.RegisterDataType(DataType{Value: &Float8{}, Name: "float8", OID: Float8OID}) + ci.RegisterDataType(DataType{Value: &Inet{}, Name: "inet", OID: InetOID}) + ci.RegisterDataType(DataType{Value: &Int2{}, Name: "int2", OID: Int2OID}) + ci.RegisterDataType(DataType{Value: &Int4{}, Name: "int4", OID: Int4OID}) + ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID}) + ci.RegisterDataType(DataType{Value: &Int8{}, Name: "int8", OID: Int8OID}) + ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID}) + ci.RegisterDataType(DataType{Value: &Interval{}, Name: "interval", OID: IntervalOID}) + ci.RegisterDataType(DataType{Value: &JSON{}, Name: "json", OID: JSONOID}) + ci.RegisterDataType(DataType{Value: &JSONB{}, Name: "jsonb", OID: JSONBOID}) + ci.RegisterDataType(DataType{Value: &JSONBArray{}, Name: "_jsonb", OID: JSONBArrayOID}) + ci.RegisterDataType(DataType{Value: &Line{}, Name: "line", OID: LineOID}) + ci.RegisterDataType(DataType{Value: &Lseg{}, Name: "lseg", OID: LsegOID}) + ci.RegisterDataType(DataType{Value: &Macaddr{}, Name: "macaddr", OID: MacaddrOID}) + ci.RegisterDataType(DataType{Value: &Name{}, Name: "name", OID: NameOID}) + ci.RegisterDataType(DataType{Value: &Numeric{}, Name: "numeric", OID: NumericOID}) + ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID}) + ci.RegisterDataType(DataType{Value: &OIDValue{}, Name: "oid", OID: OIDOID}) + ci.RegisterDataType(DataType{Value: &Path{}, Name: "path", OID: PathOID}) + ci.RegisterDataType(DataType{Value: &Point{}, Name: "point", OID: PointOID}) + ci.RegisterDataType(DataType{Value: &Polygon{}, Name: "polygon", OID: PolygonOID}) + ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID}) + ci.RegisterDataType(DataType{Value: &Text{}, Name: "text", OID: TextOID}) + ci.RegisterDataType(DataType{Value: &TID{}, Name: "tid", OID: TIDOID}) + ci.RegisterDataType(DataType{Value: &Time{}, Name: "time", OID: TimeOID}) + ci.RegisterDataType(DataType{Value: &Timestamp{}, Name: "timestamp", OID: TimestampOID}) + ci.RegisterDataType(DataType{Value: &Timestamptz{}, Name: "timestamptz", OID: TimestamptzOID}) + ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID}) + ci.RegisterDataType(DataType{Value: &TsrangeArray{}, Name: "_tsrange", OID: TsrangeArrayOID}) + ci.RegisterDataType(DataType{Value: &Tstzrange{}, Name: "tstzrange", OID: TstzrangeOID}) + ci.RegisterDataType(DataType{Value: &TstzrangeArray{}, Name: "_tstzrange", OID: TstzrangeArrayOID}) + ci.RegisterDataType(DataType{Value: &Unknown{}, Name: "unknown", OID: UnknownOID}) + ci.RegisterDataType(DataType{Value: &UUID{}, Name: "uuid", OID: UUIDOID}) + ci.RegisterDataType(DataType{Value: &Varbit{}, Name: "varbit", OID: VarbitOID}) + ci.RegisterDataType(DataType{Value: &Varchar{}, Name: "varchar", OID: VarcharOID}) + ci.RegisterDataType(DataType{Value: &XID{}, Name: "xid", OID: XIDOID}) + + registerDefaultPgTypeVariants := func(name, arrayName string, value interface{}) { + ci.RegisterDefaultPgType(value, name) + valueType := reflect.TypeOf(value) + + ci.RegisterDefaultPgType(reflect.New(valueType).Interface(), name) + + sliceType := reflect.SliceOf(valueType) + ci.RegisterDefaultPgType(reflect.MakeSlice(sliceType, 0, 0).Interface(), arrayName) + + ci.RegisterDefaultPgType(reflect.New(sliceType).Interface(), arrayName) + } + + // Integer types that directly map to a PostgreSQL type + registerDefaultPgTypeVariants("int2", "_int2", int16(0)) + registerDefaultPgTypeVariants("int4", "_int4", int32(0)) + registerDefaultPgTypeVariants("int8", "_int8", int64(0)) + + // Integer types that do not have a direct match to a PostgreSQL type + registerDefaultPgTypeVariants("int8", "_int8", uint16(0)) + registerDefaultPgTypeVariants("int8", "_int8", uint32(0)) + registerDefaultPgTypeVariants("int8", "_int8", uint64(0)) + registerDefaultPgTypeVariants("int8", "_int8", int(0)) + registerDefaultPgTypeVariants("int8", "_int8", uint(0)) + + registerDefaultPgTypeVariants("float4", "_float4", float32(0)) + registerDefaultPgTypeVariants("float8", "_float8", float64(0)) + + registerDefaultPgTypeVariants("bool", "_bool", false) + registerDefaultPgTypeVariants("timestamptz", "_timestamptz", time.Time{}) + registerDefaultPgTypeVariants("text", "_text", "") + registerDefaultPgTypeVariants("bytea", "_bytea", []byte(nil)) + + registerDefaultPgTypeVariants("inet", "_inet", net.IP{}) + ci.RegisterDefaultPgType((*net.IPNet)(nil), "cidr") + ci.RegisterDefaultPgType([]*net.IPNet(nil), "_cidr") + + return ci +} + +func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]uint32) { + for name, oid := range nameOIDs { + var value Value + if t, ok := nameValues[name]; ok { + value = reflect.New(reflect.ValueOf(t).Elem().Type()).Interface().(Value) + } else { + value = &GenericText{} + } + ci.RegisterDataType(DataType{Value: value, Name: name, OID: oid}) + } +} + +func (ci *ConnInfo) RegisterDataType(t DataType) { + t.Value = NewValue(t.Value) + + ci.oidToDataType[t.OID] = &t + ci.nameToDataType[t.Name] = &t + + { + var formatCode int16 + if pfp, ok := t.Value.(ParamFormatPreferrer); ok { + formatCode = pfp.PreferredParamFormat() + } else if _, ok := t.Value.(BinaryEncoder); ok { + formatCode = BinaryFormatCode + } + ci.oidToParamFormatCode[t.OID] = formatCode + } + + { + var formatCode int16 + if rfp, ok := t.Value.(ResultFormatPreferrer); ok { + formatCode = rfp.PreferredResultFormat() + } else if _, ok := t.Value.(BinaryDecoder); ok { + formatCode = BinaryFormatCode + } + ci.oidToResultFormatCode[t.OID] = formatCode + } + + if d, ok := t.Value.(TextDecoder); ok { + t.textDecoder = d + } + + if d, ok := t.Value.(BinaryDecoder); ok { + t.binaryDecoder = d + } + + ci.reflectTypeToDataType = nil // Invalidated by type registration +} + +// RegisterDefaultPgType registers a mapping of a Go type to a PostgreSQL type name. Typically the data type to be +// encoded or decoded is determined by the PostgreSQL OID. But if the OID of a value to be encoded or decoded is +// unknown, this additional mapping will be used by DataTypeForValue to determine a suitable data type. +func (ci *ConnInfo) RegisterDefaultPgType(value interface{}, name string) { + ci.reflectTypeToName[reflect.TypeOf(value)] = name + ci.reflectTypeToDataType = nil // Invalidated by registering a default type +} + +func (ci *ConnInfo) DataTypeForOID(oid uint32) (*DataType, bool) { + dt, ok := ci.oidToDataType[oid] + return dt, ok +} + +func (ci *ConnInfo) DataTypeForName(name string) (*DataType, bool) { + dt, ok := ci.nameToDataType[name] + return dt, ok +} + +func (ci *ConnInfo) buildReflectTypeToDataType() { + ci.reflectTypeToDataType = make(map[reflect.Type]*DataType) + + for _, dt := range ci.oidToDataType { + if _, is := dt.Value.(TypeValue); !is { + ci.reflectTypeToDataType[reflect.ValueOf(dt.Value).Type()] = dt + } + } + + for reflectType, name := range ci.reflectTypeToName { + if dt, ok := ci.nameToDataType[name]; ok { + ci.reflectTypeToDataType[reflectType] = dt + } + } +} + +// DataTypeForValue finds a data type suitable for v. Use RegisterDataType to register types that can encode and decode +// themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type. +func (ci *ConnInfo) DataTypeForValue(v interface{}) (*DataType, bool) { + if ci.reflectTypeToDataType == nil { + ci.buildReflectTypeToDataType() + } + + if tv, ok := v.(TypeValue); ok { + dt, ok := ci.nameToDataType[tv.TypeName()] + return dt, ok + } + + dt, ok := ci.reflectTypeToDataType[reflect.TypeOf(v)] + return dt, ok +} + +func (ci *ConnInfo) ParamFormatCodeForOID(oid uint32) int16 { + fc, ok := ci.oidToParamFormatCode[oid] + if ok { + return fc + } + return TextFormatCode +} + +func (ci *ConnInfo) ResultFormatCodeForOID(oid uint32) int16 { + fc, ok := ci.oidToResultFormatCode[oid] + if ok { + return fc + } + return TextFormatCode +} + +// DeepCopy makes a deep copy of the ConnInfo. +func (ci *ConnInfo) DeepCopy() *ConnInfo { + ci2 := newConnInfo() + + for _, dt := range ci.oidToDataType { + ci2.RegisterDataType(DataType{ + Value: NewValue(dt.Value), + Name: dt.Name, + OID: dt.OID, + }) + } + + for t, n := range ci.reflectTypeToName { + ci2.reflectTypeToName[t] = n + } + + return ci2 +} + +// ScanPlan is a precompiled plan to scan into a type of destination. +type ScanPlan interface { + // Scan scans src into dst. If the dst type has changed in an incompatible way a ScanPlan should automatically + // replan and scan. + Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error +} + +type scanPlanDstBinaryDecoder struct{} + +func (scanPlanDstBinaryDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if d, ok := (dst).(BinaryDecoder); ok { + return d.DecodeBinary(ci, src) + } + + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) +} + +type scanPlanDstTextDecoder struct{} + +func (plan scanPlanDstTextDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if d, ok := (dst).(TextDecoder); ok { + return d.DecodeText(ci, src) + } + + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) +} + +type scanPlanDataTypeSQLScanner DataType + +func (plan *scanPlanDataTypeSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner, ok := dst.(sql.Scanner) + if !ok { + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) + } + + dt := (*DataType)(plan) + var err error + switch formatCode { + case BinaryFormatCode: + err = dt.binaryDecoder.DecodeBinary(ci, src) + case TextFormatCode: + err = dt.textDecoder.DecodeText(ci, src) + } + if err != nil { + return err + } + + sqlSrc, err := DatabaseSQLValue(ci, dt.Value) + if err != nil { + return err + } + return scanner.Scan(sqlSrc) +} + +type scanPlanDataTypeAssignTo DataType + +func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + dt := (*DataType)(plan) + var err error + switch formatCode { + case BinaryFormatCode: + err = dt.binaryDecoder.DecodeBinary(ci, src) + case TextFormatCode: + err = dt.textDecoder.DecodeText(ci, src) + } + if err != nil { + return err + } + + assignToErr := dt.Value.AssignTo(dst) + if assignToErr == nil { + return nil + } + + if dstPtr, ok := dst.(*interface{}); ok { + *dstPtr = dt.Value.Get() + return nil + } + + // assignToErr might have failed because the type of destination has changed + newPlan := ci.PlanScan(oid, formatCode, dst) + if newPlan, sameType := newPlan.(*scanPlanDataTypeAssignTo); !sameType { + return newPlan.Scan(ci, oid, formatCode, src, dst) + } + + return assignToErr +} + +type scanPlanSQLScanner struct{} + +func (scanPlanSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := dst.(sql.Scanner) + if formatCode == BinaryFormatCode { + return scanner.Scan(src) + } else { + return scanner.Scan(string(src)) + } +} + +type scanPlanReflection struct{} + +func (scanPlanReflection) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + // We might be given a pointer to something that implements the decoder interface(s), + // even though the pointer itself doesn't. + refVal := reflect.ValueOf(dst) + if refVal.Kind() == reflect.Ptr && refVal.Type().Elem().Kind() == reflect.Ptr { + // If the database returned NULL, then we set dest as nil to indicate that. + if src == nil { + nilPtr := reflect.Zero(refVal.Type().Elem()) + refVal.Elem().Set(nilPtr) + return nil + } + + // We need to allocate an element, and set the destination to it + // Then we can retry as that element. + elemPtr := reflect.New(refVal.Type().Elem().Elem()) + refVal.Elem().Set(elemPtr) + + plan := ci.PlanScan(oid, formatCode, elemPtr.Interface()) + return plan.Scan(ci, oid, formatCode, src, elemPtr.Interface()) + } + + return scanUnknownType(oid, formatCode, src, dst) +} + +type scanPlanBinaryInt16 struct{} + +func (scanPlanBinaryInt16) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) + } + + if p, ok := (dst).(*int16); ok { + *p = int16(binary.BigEndian.Uint16(src)) + return nil + } + + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) +} + +type scanPlanBinaryInt32 struct{} + +func (scanPlanBinaryInt32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + if p, ok := (dst).(*int32); ok { + *p = int32(binary.BigEndian.Uint32(src)) + return nil + } + + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) +} + +type scanPlanBinaryInt64 struct{} + +func (scanPlanBinaryInt64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + if p, ok := (dst).(*int64); ok { + *p = int64(binary.BigEndian.Uint64(src)) + return nil + } + + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) +} + +type scanPlanBinaryFloat32 struct{} + +func (scanPlanBinaryFloat32) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) + } + + if p, ok := (dst).(*float32); ok { + n := int32(binary.BigEndian.Uint32(src)) + *p = float32(math.Float32frombits(uint32(n))) + return nil + } + + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) +} + +type scanPlanBinaryFloat64 struct{} + +func (scanPlanBinaryFloat64) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) + } + + if p, ok := (dst).(*float64); ok { + n := int64(binary.BigEndian.Uint64(src)) + *p = float64(math.Float64frombits(uint64(n))) + return nil + } + + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) +} + +type scanPlanBinaryBytes struct{} + +func (scanPlanBinaryBytes) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if p, ok := (dst).(*[]byte); ok { + *p = src + return nil + } + + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) +} + +type scanPlanString struct{} + +func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan null into %T", dst) + } + + if p, ok := (dst).(*string); ok { + *p = string(src) + return nil + } + + newPlan := ci.PlanScan(oid, formatCode, dst) + return newPlan.Scan(ci, oid, formatCode, src, dst) +} + +// PlanScan prepares a plan to scan a value into dst. +func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) ScanPlan { + switch formatCode { + case BinaryFormatCode: + switch dst.(type) { + case *string: + switch oid { + case TextOID, VarcharOID: + return scanPlanString{} + } + case *int16: + if oid == Int2OID { + return scanPlanBinaryInt16{} + } + case *int32: + if oid == Int4OID { + return scanPlanBinaryInt32{} + } + case *int64: + if oid == Int8OID { + return scanPlanBinaryInt64{} + } + case *float32: + if oid == Float4OID { + return scanPlanBinaryFloat32{} + } + case *float64: + if oid == Float8OID { + return scanPlanBinaryFloat64{} + } + case *[]byte: + switch oid { + case ByteaOID, TextOID, VarcharOID, JSONOID: + return scanPlanBinaryBytes{} + } + case BinaryDecoder: + return scanPlanDstBinaryDecoder{} + } + case TextFormatCode: + switch dst.(type) { + case *string: + return scanPlanString{} + case *[]byte: + if oid != ByteaOID { + return scanPlanBinaryBytes{} + } + case TextDecoder: + return scanPlanDstTextDecoder{} + } + } + + var dt *DataType + + if oid == 0 { + if dataType, ok := ci.DataTypeForValue(dst); ok { + dt = dataType + } + } else { + if dataType, ok := ci.DataTypeForOID(oid); ok { + dt = dataType + } + } + + if dt != nil { + if _, ok := dst.(sql.Scanner); ok { + return (*scanPlanDataTypeSQLScanner)(dt) + } + return (*scanPlanDataTypeAssignTo)(dt) + } + + if _, ok := dst.(sql.Scanner); ok { + return scanPlanSQLScanner{} + } + + return scanPlanReflection{} +} + +func (ci *ConnInfo) Scan(oid uint32, formatCode int16, src []byte, dst interface{}) error { + if dst == nil { + return nil + } + + plan := ci.PlanScan(oid, formatCode, dst) + return plan.Scan(ci, oid, formatCode, src, dst) +} + +func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest interface{}) error { + switch dest := dest.(type) { + case *string: + if formatCode == BinaryFormatCode { + return fmt.Errorf("unknown oid %d in binary format cannot be scanned into %T", oid, dest) + } + *dest = string(buf) + return nil + case *[]byte: + *dest = buf + return nil + default: + if nextDst, retry := GetAssignToDstType(dest); retry { + return scanUnknownType(oid, formatCode, buf, nextDst) + } + return fmt.Errorf("unknown oid %d cannot be scanned into %T", oid, dest) + } +} + +// NewValue returns a new instance of the same type as v. +func NewValue(v Value) Value { + if tv, ok := v.(TypeValue); ok { + return tv.NewTypeValue() + } else { + return reflect.New(reflect.ValueOf(v).Elem().Type()).Interface().(Value) + } +} + +var nameValues map[string]Value + +func init() { + nameValues = map[string]Value{ + "_aclitem": &ACLItemArray{}, + "_bool": &BoolArray{}, + "_bpchar": &BPCharArray{}, + "_bytea": &ByteaArray{}, + "_cidr": &CIDRArray{}, + "_date": &DateArray{}, + "_float4": &Float4Array{}, + "_float8": &Float8Array{}, + "_inet": &InetArray{}, + "_int2": &Int2Array{}, + "_int4": &Int4Array{}, + "_int8": &Int8Array{}, + "_numeric": &NumericArray{}, + "_text": &TextArray{}, + "_timestamp": &TimestampArray{}, + "_timestamptz": &TimestamptzArray{}, + "_uuid": &UUIDArray{}, + "_varchar": &VarcharArray{}, + "_jsonb": &JSONBArray{}, + "aclitem": &ACLItem{}, + "bit": &Bit{}, + "bool": &Bool{}, + "box": &Box{}, + "bpchar": &BPChar{}, + "bytea": &Bytea{}, + "char": &QChar{}, + "cid": &CID{}, + "cidr": &CIDR{}, + "circle": &Circle{}, + "date": &Date{}, + "daterange": &Daterange{}, + "float4": &Float4{}, + "float8": &Float8{}, + "hstore": &Hstore{}, + "inet": &Inet{}, + "int2": &Int2{}, + "int4": &Int4{}, + "int4range": &Int4range{}, + "int8": &Int8{}, + "int8range": &Int8range{}, + "interval": &Interval{}, + "json": &JSON{}, + "jsonb": &JSONB{}, + "line": &Line{}, + "lseg": &Lseg{}, + "macaddr": &Macaddr{}, + "name": &Name{}, + "numeric": &Numeric{}, + "numrange": &Numrange{}, + "oid": &OIDValue{}, + "path": &Path{}, + "point": &Point{}, + "polygon": &Polygon{}, + "record": &Record{}, + "text": &Text{}, + "tid": &TID{}, + "timestamp": &Timestamp{}, + "timestamptz": &Timestamptz{}, + "tsrange": &Tsrange{}, + "_tsrange": &TsrangeArray{}, + "tstzrange": &Tstzrange{}, + "_tstzrange": &TstzrangeArray{}, + "unknown": &Unknown{}, + "uuid": &UUID{}, + "varbit": &Varbit{}, + "varchar": &Varchar{}, + "xid": &XID{}, + } +} diff --git a/vendor/github.com/jackc/pgtype/pguint32.go b/vendor/github.com/jackc/pgtype/pguint32.go new file mode 100644 index 000000000..a0e88ca2a --- /dev/null +++ b/vendor/github.com/jackc/pgtype/pguint32.go @@ -0,0 +1,162 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgio" +) + +// pguint32 is the core type that is used to implement PostgreSQL types such as +// CID and XID. +type pguint32 struct { + Uint uint32 + Status Status +} + +// Set converts from src to dst. Note that as pguint32 is not a general +// number type Set does not do automatic type conversion as other number +// types do. +func (dst *pguint32) Set(src interface{}) error { + switch value := src.(type) { + case int64: + if value < 0 { + return fmt.Errorf("%d is less than minimum value for pguint32", value) + } + if value > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for pguint32", value) + } + *dst = pguint32{Uint: uint32(value), Status: Present} + case uint32: + *dst = pguint32{Uint: value, Status: Present} + default: + return fmt.Errorf("cannot convert %v to pguint32", value) + } + + return nil +} + +func (dst pguint32) Get() interface{} { + switch dst.Status { + case Present: + return dst.Uint + case Null: + return nil + default: + return dst.Status + } +} + +// AssignTo assigns from src to dst. Note that as pguint32 is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *pguint32) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *uint32: + if src.Status == Present { + *v = src.Uint + } else { + return fmt.Errorf("cannot assign %v into %T", src, dst) + } + case **uint32: + if src.Status == Present { + n := src.Uint + *v = &n + } else { + *v = nil + } + } + + return nil +} + +func (dst *pguint32) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = pguint32{Status: Null} + return nil + } + + n, err := strconv.ParseUint(string(src), 10, 32) + if err != nil { + return err + } + + *dst = pguint32{Uint: uint32(n), Status: Present} + return nil +} + +func (dst *pguint32) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = pguint32{Status: Null} + return nil + } + + if len(src) != 4 { + return fmt.Errorf("invalid length: %v", len(src)) + } + + n := binary.BigEndian.Uint32(src) + *dst = pguint32{Uint: n, Status: Present} + return nil +} + +func (src pguint32) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, strconv.FormatUint(uint64(src.Uint), 10)...), nil +} + +func (src pguint32) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return pgio.AppendUint32(buf, src.Uint), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *pguint32) Scan(src interface{}) error { + if src == nil { + *dst = pguint32{Status: Null} + return nil + } + + switch src := src.(type) { + case uint32: + *dst = pguint32{Uint: src, Status: Present} + return nil + case int64: + *dst = pguint32{Uint: uint32(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src pguint32) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Uint), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/vendor/github.com/jackc/pgtype/point.go b/vendor/github.com/jackc/pgtype/point.go new file mode 100644 index 000000000..0c799106c --- /dev/null +++ b/vendor/github.com/jackc/pgtype/point.go @@ -0,0 +1,214 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgio" +) + +type Vec2 struct { + X float64 + Y float64 +} + +type Point struct { + P Vec2 + Status Status +} + +func (dst *Point) Set(src interface{}) error { + if src == nil { + dst.Status = Null + return nil + } + err := fmt.Errorf("cannot convert %v to Point", src) + var p *Point + switch value := src.(type) { + case string: + p, err = parsePoint([]byte(value)) + case []byte: + p, err = parsePoint(value) + default: + return err + } + if err != nil { + return err + } + *dst = *p + return nil +} + +func parsePoint(src []byte) (*Point, error) { + if src == nil || bytes.Compare(src, []byte("null")) == 0 { + return &Point{Status: Null}, nil + } + + if len(src) < 5 { + return nil, fmt.Errorf("invalid length for point: %v", len(src)) + } + if src[0] == '"' && src[len(src)-1] == '"' { + src = src[1 : len(src)-1] + } + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) + if len(parts) < 2 { + return nil, fmt.Errorf("invalid format for point") + } + + x, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return nil, err + } + + y, err := strconv.ParseFloat(parts[1], 64) + if err != nil { + return nil, err + } + + return &Point{P: Vec2{x, y}, Status: Present}, nil +} + +func (dst Point) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Point) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Point{Status: Null} + return nil + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for point: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) + if len(parts) < 2 { + return fmt.Errorf("invalid format for point") + } + + x, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return err + } + + y, err := strconv.ParseFloat(parts[1], 64) + if err != nil { + return err + } + + *dst = Point{P: Vec2{x, y}, Status: Present} + return nil +} + +func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Point{Status: Null} + return nil + } + + if len(src) != 16 { + return fmt.Errorf("invalid length for point: %v", len(src)) + } + + x := binary.BigEndian.Uint64(src) + y := binary.BigEndian.Uint64(src[8:]) + + *dst = Point{ + P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, + Status: Present, + } + return nil +} + +func (src Point) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, fmt.Sprintf(`(%s,%s)`, + strconv.FormatFloat(src.P.X, 'f', -1, 64), + strconv.FormatFloat(src.P.Y, 'f', -1, 64), + )...), nil +} + +func (src Point) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendUint64(buf, math.Float64bits(src.P.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P.Y)) + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Point) Scan(src interface{}) error { + if src == nil { + *dst = Point{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Point) Value() (driver.Value, error) { + return EncodeValueText(src) +} + +func (src Point) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + var buff bytes.Buffer + buff.WriteByte('"') + buff.WriteString(fmt.Sprintf("(%g,%g)", src.P.X, src.P.Y)) + buff.WriteByte('"') + return buff.Bytes(), nil + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + return nil, errBadStatus +} + +func (dst *Point) UnmarshalJSON(point []byte) error { + p, err := parsePoint(point) + if err != nil { + return err + } + *dst = *p + return nil +} diff --git a/vendor/github.com/jackc/pgtype/polygon.go b/vendor/github.com/jackc/pgtype/polygon.go new file mode 100644 index 000000000..207cadc00 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/polygon.go @@ -0,0 +1,226 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgio" +) + +type Polygon struct { + P []Vec2 + Status Status +} + +// Set converts src to dest. +// +// src can be nil, string, []float64, and []pgtype.Vec2. +// +// If src is string the format must be ((x1,y1),(x2,y2),...,(xn,yn)). +// Important that there are no spaces in it. +func (dst *Polygon) Set(src interface{}) error { + if src == nil { + dst.Status = Null + return nil + } + err := fmt.Errorf("cannot convert %v to Polygon", src) + var p *Polygon + switch value := src.(type) { + case string: + p, err = stringToPolygon(value) + case []Vec2: + p = &Polygon{Status: Present, P: value} + err = nil + case []float64: + p, err = float64ToPolygon(value) + default: + return err + } + if err != nil { + return err + } + *dst = *p + return nil +} + +func stringToPolygon(src string) (*Polygon, error) { + p := &Polygon{} + err := p.DecodeText(nil, []byte(src)) + return p, err +} + +func float64ToPolygon(src []float64) (*Polygon, error) { + p := &Polygon{Status: Null} + if len(src) == 0 { + return p, nil + } + if len(src)%2 != 0 { + p.Status = Undefined + return p, fmt.Errorf("invalid length for polygon: %v", len(src)) + } + p.Status = Present + p.P = make([]Vec2, 0) + for i := 0; i < len(src); i += 2 { + p.P = append(p.P, Vec2{X: src[i], Y: src[i+1]}) + } + return p, nil +} + +func (dst Polygon) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Polygon) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Polygon) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Polygon{Status: Null} + return nil + } + + if len(src) < 7 { + return fmt.Errorf("invalid length for Polygon: %v", len(src)) + } + + points := make([]Vec2, 0) + + str := string(src[2:]) + + for { + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + points = append(points, Vec2{x, y}) + + if end+3 < len(str) { + str = str[end+3:] + } else { + break + } + } + + *dst = Polygon{P: points, Status: Present} + return nil +} + +func (dst *Polygon) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Polygon{Status: Null} + return nil + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for Polygon: %v", len(src)) + } + + pointCount := int(binary.BigEndian.Uint32(src)) + rp := 4 + + if 4+pointCount*16 != len(src) { + return fmt.Errorf("invalid length for Polygon with %d points: %v", pointCount, len(src)) + } + + points := make([]Vec2, pointCount) + for i := 0; i < len(points); i++ { + x := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + y := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + points[i] = Vec2{math.Float64frombits(x), math.Float64frombits(y)} + } + + *dst = Polygon{ + P: points, + Status: Present, + } + return nil +} + +func (src Polygon) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, '(') + + for i, p := range src.P { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, fmt.Sprintf(`(%s,%s)`, + strconv.FormatFloat(p.X, 'f', -1, 64), + strconv.FormatFloat(p.Y, 'f', -1, 64), + )...) + } + + return append(buf, ')'), nil +} + +func (src Polygon) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendInt32(buf, int32(len(src.P))) + + for _, p := range src.P { + buf = pgio.AppendUint64(buf, math.Float64bits(p.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(p.Y)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Polygon) Scan(src interface{}) error { + if src == nil { + *dst = Polygon{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Polygon) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/vendor/github.com/jackc/pgtype/qchar.go b/vendor/github.com/jackc/pgtype/qchar.go new file mode 100644 index 000000000..574f6066c --- /dev/null +++ b/vendor/github.com/jackc/pgtype/qchar.go @@ -0,0 +1,152 @@ +package pgtype + +import ( + "fmt" + "math" + "strconv" +) + +// QChar is for PostgreSQL's special 8-bit-only "char" type more akin to the C +// language's char type, or Go's byte type. (Note that the name in PostgreSQL +// itself is "char", in double-quotes, and not char.) It gets used a lot in +// PostgreSQL's system tables to hold a single ASCII character value (eg +// pg_class.relkind). It is named Qchar for quoted char to disambiguate from SQL +// standard type char. +// +// Not all possible values of QChar are representable in the text format. +// Therefore, QChar does not implement TextEncoder and TextDecoder. In +// addition, database/sql Scanner and database/sql/driver Value are not +// implemented. +type QChar struct { + Int int8 + Status Status +} + +func (dst *QChar) Set(src interface{}) error { + if src == nil { + *dst = QChar{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case int8: + *dst = QChar{Int: value, Status: Present} + case uint8: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case int16: + if value < math.MinInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case uint16: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case int32: + if value < math.MinInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case uint32: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case int64: + if value < math.MinInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case uint64: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case int: + if value < math.MinInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case uint: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 8) + if err != nil { + return err + } + *dst = QChar{Int: int8(num), Status: Present} + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to QChar", value) + } + + return nil +} + +func (dst QChar) Get() interface{} { + switch dst.Status { + case Present: + return dst.Int + case Null: + return nil + default: + return dst.Status + } +} + +func (src *QChar) AssignTo(dst interface{}) error { + return int64AssignTo(int64(src.Int), src.Status, dst) +} + +func (dst *QChar) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = QChar{Status: Null} + return nil + } + + if len(src) != 1 { + return fmt.Errorf(`invalid length for "char": %v`, len(src)) + } + + *dst = QChar{Int: int8(src[0]), Status: Present} + return nil +} + +func (src QChar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, byte(src.Int)), nil +} diff --git a/vendor/github.com/jackc/pgtype/range.go b/vendor/github.com/jackc/pgtype/range.go new file mode 100644 index 000000000..e999f6a91 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/range.go @@ -0,0 +1,277 @@ +package pgtype + +import ( + "bytes" + "encoding/binary" + "fmt" +) + +type BoundType byte + +const ( + Inclusive = BoundType('i') + Exclusive = BoundType('e') + Unbounded = BoundType('U') + Empty = BoundType('E') +) + +func (bt BoundType) String() string { + return string(bt) +} + +type UntypedTextRange struct { + Lower string + Upper string + LowerType BoundType + UpperType BoundType +} + +func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { + utr := &UntypedTextRange{} + if src == "empty" { + utr.LowerType = Empty + utr.UpperType = Empty + return utr, nil + } + + buf := bytes.NewBufferString(src) + + skipWhitespace(buf) + + r, _, err := buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid lower bound: %v", err) + } + switch r { + case '(': + utr.LowerType = Exclusive + case '[': + utr.LowerType = Inclusive + default: + return nil, fmt.Errorf("missing lower bound, instead got: %v", string(r)) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid lower value: %v", err) + } + buf.UnreadRune() + + if r == ',' { + utr.LowerType = Unbounded + } else { + utr.Lower, err = rangeParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid lower value: %v", err) + } + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("missing range separator: %v", err) + } + if r != ',' { + return nil, fmt.Errorf("missing range separator: %v", r) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid upper value: %v", err) + } + + if r == ')' || r == ']' { + utr.UpperType = Unbounded + } else { + buf.UnreadRune() + utr.Upper, err = rangeParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid upper value: %v", err) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("missing upper bound: %v", err) + } + switch r { + case ')': + utr.UpperType = Exclusive + case ']': + utr.UpperType = Inclusive + default: + return nil, fmt.Errorf("missing upper bound, instead got: %v", string(r)) + } + } + + skipWhitespace(buf) + + if buf.Len() > 0 { + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + } + + return utr, nil +} + +func rangeParseValue(buf *bytes.Buffer) (string, error) { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + if r == '"' { + return rangeParseQuotedValue(buf) + } + buf.UnreadRune() + + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + case ',', '[', ']', '(', ')': + buf.UnreadRune() + return s.String(), nil + } + + s.WriteRune(r) + } +} + +func rangeParseQuotedValue(buf *bytes.Buffer) (string, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + case '"': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + if r != '"' { + buf.UnreadRune() + return s.String(), nil + } + } + s.WriteRune(r) + } +} + +type UntypedBinaryRange struct { + Lower []byte + Upper []byte + LowerType BoundType + UpperType BoundType +} + +// 0 = () = 00000 +// 1 = empty = 00001 +// 2 = [) = 00010 +// 4 = (] = 00100 +// 6 = [] = 00110 +// 8 = ) = 01000 +// 12 = ] = 01100 +// 16 = ( = 10000 +// 18 = [ = 10010 +// 24 = = 11000 + +const emptyMask = 1 +const lowerInclusiveMask = 2 +const upperInclusiveMask = 4 +const lowerUnboundedMask = 8 +const upperUnboundedMask = 16 + +func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { + ubr := &UntypedBinaryRange{} + + if len(src) == 0 { + return nil, fmt.Errorf("range too short: %v", len(src)) + } + + rangeType := src[0] + rp := 1 + + if rangeType&emptyMask > 0 { + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing empty range: %v", len(src[rp:])) + } + ubr.LowerType = Empty + ubr.UpperType = Empty + return ubr, nil + } + + if rangeType&lowerInclusiveMask > 0 { + ubr.LowerType = Inclusive + } else if rangeType&lowerUnboundedMask > 0 { + ubr.LowerType = Unbounded + } else { + ubr.LowerType = Exclusive + } + + if rangeType&upperInclusiveMask > 0 { + ubr.UpperType = Inclusive + } else if rangeType&upperUnboundedMask > 0 { + ubr.UpperType = Unbounded + } else { + ubr.UpperType = Exclusive + } + + if ubr.LowerType == Unbounded && ubr.UpperType == Unbounded { + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing unbounded range: %v", len(src[rp:])) + } + return ubr, nil + } + + if len(src[rp:]) < 4 { + return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) + } + valueLen := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + val := src[rp : rp+valueLen] + rp += valueLen + + if ubr.LowerType != Unbounded { + ubr.Lower = val + } else { + ubr.Upper = val + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + } + return ubr, nil + } + + if ubr.UpperType != Unbounded { + if len(src[rp:]) < 4 { + return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) + } + valueLen := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + ubr.Upper = src[rp : rp+valueLen] + rp += valueLen + } + + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + } + + return ubr, nil + +} diff --git a/vendor/github.com/jackc/pgtype/record.go b/vendor/github.com/jackc/pgtype/record.go new file mode 100644 index 000000000..718c35702 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/record.go @@ -0,0 +1,126 @@ +package pgtype + +import ( + "fmt" + "reflect" +) + +// Record is the generic PostgreSQL record type such as is created with the +// "row" function. Record only implements BinaryEncoder and Value. The text +// format output format from PostgreSQL does not include type information and is +// therefore impossible to decode. No encoders are implemented because +// PostgreSQL does not support input of generic records. +type Record struct { + Fields []Value + Status Status +} + +func (dst *Record) Set(src interface{}) error { + if src == nil { + *dst = Record{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case []Value: + *dst = Record{Fields: value, Status: Present} + default: + return fmt.Errorf("cannot convert %v to Record", src) + } + + return nil +} + +func (dst Record) Get() interface{} { + switch dst.Status { + case Present: + return dst.Fields + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Record) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *[]Value: + *v = make([]Value, len(src.Fields)) + copy(*v, src.Fields) + return nil + case *[]interface{}: + *v = make([]interface{}, len(src.Fields)) + for i := range *v { + (*v)[i] = src.Fields[i].Get() + } + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func prepareNewBinaryDecoder(ci *ConnInfo, fieldOID uint32, v *Value) (BinaryDecoder, error) { + var binaryDecoder BinaryDecoder + + if dt, ok := ci.DataTypeForOID(fieldOID); ok { + binaryDecoder, _ = dt.Value.(BinaryDecoder) + } else { + return nil, fmt.Errorf("unknown oid while decoding record: %v", fieldOID) + } + + if binaryDecoder == nil { + return nil, fmt.Errorf("no binary decoder registered for: %v", fieldOID) + } + + // Duplicate struct to scan into + binaryDecoder = reflect.New(reflect.ValueOf(binaryDecoder).Elem().Type()).Interface().(BinaryDecoder) + *v = binaryDecoder.(Value) + return binaryDecoder, nil +} + +func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Record{Status: Null} + return nil + } + + scanner := NewCompositeBinaryScanner(ci, src) + + fields := make([]Value, scanner.FieldCount()) + + for i := 0; scanner.Next(); i++ { + binaryDecoder, err := prepareNewBinaryDecoder(ci, scanner.OID(), &fields[i]) + if err != nil { + return err + } + + if err = binaryDecoder.DecodeBinary(ci, scanner.Bytes()); err != nil { + return err + } + } + + if scanner.Err() != nil { + return scanner.Err() + } + + *dst = Record{Fields: fields, Status: Present} + + return nil +} diff --git a/vendor/github.com/jackc/pgtype/text.go b/vendor/github.com/jackc/pgtype/text.go new file mode 100644 index 000000000..6b01d1b49 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/text.go @@ -0,0 +1,182 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/json" + "fmt" +) + +type Text struct { + String string + Status Status +} + +func (dst *Text) Set(src interface{}) error { + if src == nil { + *dst = Text{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case string: + *dst = Text{String: value, Status: Present} + case *string: + if value == nil { + *dst = Text{Status: Null} + } else { + *dst = Text{String: *value, Status: Present} + } + case []byte: + if value == nil { + *dst = Text{Status: Null} + } else { + *dst = Text{String: string(value), Status: Present} + } + default: + if originalSrc, ok := underlyingStringType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Text", value) + } + + return nil +} + +func (dst Text) Get() interface{} { + switch dst.Status { + case Present: + return dst.String + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Text) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *string: + *v = src.String + return nil + case *[]byte: + *v = make([]byte, len(src.String)) + copy(*v, src.String) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (Text) PreferredResultFormat() int16 { + return TextFormatCode +} + +func (dst *Text) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Text{Status: Null} + return nil + } + + *dst = Text{String: string(src), Status: Present} + return nil +} + +func (dst *Text) DecodeBinary(ci *ConnInfo, src []byte) error { + return dst.DecodeText(ci, src) +} + +func (Text) PreferredParamFormat() int16 { + return TextFormatCode +} + +func (src Text) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, src.String...), nil +} + +func (src Text) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return src.EncodeText(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Text) Scan(src interface{}) error { + if src == nil { + *dst = Text{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Text) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.String, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} + +func (src Text) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return json.Marshal(src.String) + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + + return nil, errBadStatus +} + +func (dst *Text) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *dst = Text{Status: Null} + } else { + *dst = Text{String: *s, Status: Present} + } + + return nil +} diff --git a/vendor/github.com/jackc/pgtype/text_array.go b/vendor/github.com/jackc/pgtype/text_array.go new file mode 100644 index 000000000..2461966b3 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/text_array.go @@ -0,0 +1,517 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type TextArray struct { + Elements []Text + Dimensions []ArrayDimension + Status Status +} + +func (dst *TextArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = TextArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []string: + if value == nil { + *dst = TextArray{Status: Null} + } else if len(value) == 0 { + *dst = TextArray{Status: Present} + } else { + elements := make([]Text, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TextArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*string: + if value == nil { + *dst = TextArray{Status: Null} + } else if len(value) == 0 { + *dst = TextArray{Status: Present} + } else { + elements := make([]Text, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TextArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Text: + if value == nil { + *dst = TextArray{Status: Null} + } else if len(value) == 0 { + *dst = TextArray{Status: Present} + } else { + *dst = TextArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = TextArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for TextArray", src) + } + if elementsLength == 0 { + *dst = TextArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to TextArray", src) + } + + *dst = TextArray{ + Elements: make([]Text, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Text, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to TextArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *TextArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to TextArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in TextArray", err) + } + index++ + + return index, nil +} + +func (dst TextArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *TextArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *TextArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from TextArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from TextArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TextArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Text + + if len(uta.Elements) > 0 { + elements = make([]Text, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Text + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TextArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *TextArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TextArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TextArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Text, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = TextArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src TextArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src TextArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("text"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "text") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *TextArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src TextArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/tid.go b/vendor/github.com/jackc/pgtype/tid.go new file mode 100644 index 000000000..4bb57f643 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/tid.go @@ -0,0 +1,156 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "strconv" + "strings" + + "github.com/jackc/pgio" +) + +// TID is PostgreSQL's Tuple Identifier type. +// +// When one does +// +// select ctid, * from some_table; +// +// it is the data type of the ctid hidden system column. +// +// It is currently implemented as a pair unsigned two byte integers. +// Its conversion functions can be found in src/backend/utils/adt/tid.c +// in the PostgreSQL sources. +type TID struct { + BlockNumber uint32 + OffsetNumber uint16 + Status Status +} + +func (dst *TID) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to TID", src) +} + +func (dst TID) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *TID) AssignTo(dst interface{}) error { + if src.Status == Present { + switch v := dst.(type) { + case *string: + *v = fmt.Sprintf(`(%d,%d)`, src.BlockNumber, src.OffsetNumber) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + } + + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *TID) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TID{Status: Null} + return nil + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for tid: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) + if len(parts) < 2 { + return fmt.Errorf("invalid format for tid") + } + + blockNumber, err := strconv.ParseUint(parts[0], 10, 32) + if err != nil { + return err + } + + offsetNumber, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return err + } + + *dst = TID{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Status: Present} + return nil +} + +func (dst *TID) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TID{Status: Null} + return nil + } + + if len(src) != 6 { + return fmt.Errorf("invalid length for tid: %v", len(src)) + } + + *dst = TID{ + BlockNumber: binary.BigEndian.Uint32(src), + OffsetNumber: binary.BigEndian.Uint16(src[4:]), + Status: Present, + } + return nil +} + +func (src TID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = append(buf, fmt.Sprintf(`(%d,%d)`, src.BlockNumber, src.OffsetNumber)...) + return buf, nil +} + +func (src TID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendUint32(buf, src.BlockNumber) + buf = pgio.AppendUint16(buf, src.OffsetNumber) + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *TID) Scan(src interface{}) error { + if src == nil { + *dst = TID{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src TID) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/vendor/github.com/jackc/pgtype/time.go b/vendor/github.com/jackc/pgtype/time.go new file mode 100644 index 000000000..f7a28870a --- /dev/null +++ b/vendor/github.com/jackc/pgtype/time.go @@ -0,0 +1,231 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "strconv" + "time" + + "github.com/jackc/pgio" +) + +// Time represents the PostgreSQL time type. The PostgreSQL time is a time of day without time zone. +// +// Time is represented as the number of microseconds since midnight in the same way that PostgreSQL does. Other time +// and date types in pgtype can use time.Time as the underlying representation. However, pgtype.Time type cannot due +// to needing to handle 24:00:00. time.Time converts that to 00:00:00 on the following day. +type Time struct { + Microseconds int64 // Number of microseconds since midnight + Status Status +} + +// Set converts src into a Time and stores in dst. +func (dst *Time) Set(src interface{}) error { + if src == nil { + *dst = Time{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case time.Time: + usec := int64(value.Hour())*microsecondsPerHour + + int64(value.Minute())*microsecondsPerMinute + + int64(value.Second())*microsecondsPerSecond + + int64(value.Nanosecond())/1000 + *dst = Time{Microseconds: usec, Status: Present} + case *time.Time: + if value == nil { + *dst = Time{Status: Null} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Time", value) + } + + return nil +} + +func (dst Time) Get() interface{} { + switch dst.Status { + case Present: + return dst.Microseconds + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Time) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *time.Time: + // 24:00:00 is max allowed time in PostgreSQL, but time.Time will normalize that to 00:00:00 the next day. + var maxRepresentableByTime int64 = 24*60*60*1000000 - 1 + if src.Microseconds > maxRepresentableByTime { + return fmt.Errorf("%d microseconds cannot be represented as time.Time", src.Microseconds) + } + + usec := src.Microseconds + hours := usec / microsecondsPerHour + usec -= hours * microsecondsPerHour + minutes := usec / microsecondsPerMinute + usec -= minutes * microsecondsPerMinute + seconds := usec / microsecondsPerSecond + usec -= seconds * microsecondsPerSecond + ns := usec * 1000 + *v = time.Date(2000, 1, 1, int(hours), int(minutes), int(seconds), int(ns), time.UTC) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +// DecodeText decodes from src into dst. +func (dst *Time) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Time{Status: Null} + return nil + } + + s := string(src) + + if len(s) < 8 { + return fmt.Errorf("cannot decode %v into Time", s) + } + + hours, err := strconv.ParseInt(s[0:2], 10, 64) + if err != nil { + return fmt.Errorf("cannot decode %v into Time", s) + } + usec := hours * microsecondsPerHour + + minutes, err := strconv.ParseInt(s[3:5], 10, 64) + if err != nil { + return fmt.Errorf("cannot decode %v into Time", s) + } + usec += minutes * microsecondsPerMinute + + seconds, err := strconv.ParseInt(s[6:8], 10, 64) + if err != nil { + return fmt.Errorf("cannot decode %v into Time", s) + } + usec += seconds * microsecondsPerSecond + + if len(s) > 9 { + fraction := s[9:] + n, err := strconv.ParseInt(fraction, 10, 64) + if err != nil { + return fmt.Errorf("cannot decode %v into Time", s) + } + + for i := len(fraction); i < 6; i++ { + n *= 10 + } + + usec += n + } + + *dst = Time{Microseconds: usec, Status: Present} + + return nil +} + +// DecodeBinary decodes from src into dst. +func (dst *Time) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Time{Status: Null} + return nil + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for time: %v", len(src)) + } + + usec := int64(binary.BigEndian.Uint64(src)) + *dst = Time{Microseconds: usec, Status: Present} + + return nil +} + +// EncodeText writes the text encoding of src into w. +func (src Time) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + usec := src.Microseconds + hours := usec / microsecondsPerHour + usec -= hours * microsecondsPerHour + minutes := usec / microsecondsPerMinute + usec -= minutes * microsecondsPerMinute + seconds := usec / microsecondsPerSecond + usec -= seconds * microsecondsPerSecond + + s := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, usec) + + return append(buf, s...), nil +} + +// EncodeBinary writes the binary encoding of src into w. If src.Time is not in +// the UTC time zone it returns an error. +func (src Time) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return pgio.AppendInt64(buf, src.Microseconds), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Time) Scan(src interface{}) error { + if src == nil { + *dst = Time{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + case time.Time: + return dst.Set(src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Time) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/vendor/github.com/jackc/pgtype/timestamp.go b/vendor/github.com/jackc/pgtype/timestamp.go new file mode 100644 index 000000000..466441158 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/timestamp.go @@ -0,0 +1,241 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "time" + + "github.com/jackc/pgio" +) + +const pgTimestampFormat = "2006-01-02 15:04:05.999999999" + +// Timestamp represents the PostgreSQL timestamp type. The PostgreSQL +// timestamp does not have a time zone. This presents a problem when +// translating to and from time.Time which requires a time zone. It is highly +// recommended to use timestamptz whenever possible. Timestamp methods either +// convert to UTC or return an error on non-UTC times. +type Timestamp struct { + Time time.Time // Time must always be in UTC. + Status Status + InfinityModifier InfinityModifier +} + +// Set converts src into a Timestamp and stores in dst. If src is a +// time.Time in a non-UTC time zone, the time zone is discarded. +func (dst *Timestamp) Set(src interface{}) error { + if src == nil { + *dst = Timestamp{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case time.Time: + *dst = Timestamp{Time: time.Date(value.Year(), value.Month(), value.Day(), value.Hour(), value.Minute(), value.Second(), value.Nanosecond(), time.UTC), Status: Present} + case *time.Time: + if value == nil { + *dst = Timestamp{Status: Null} + } else { + return dst.Set(*value) + } + case InfinityModifier: + *dst = Timestamp{InfinityModifier: value, Status: Present} + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Timestamp", value) + } + + return nil +} + +func (dst Timestamp) Get() interface{} { + switch dst.Status { + case Present: + if dst.InfinityModifier != None { + return dst.InfinityModifier + } + return dst.Time + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Timestamp) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *time.Time: + if src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +// DecodeText decodes from src into dst. The decoded time is considered to +// be in UTC. +func (dst *Timestamp) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Timestamp{Status: Null} + return nil + } + + sbuf := string(src) + switch sbuf { + case "infinity": + *dst = Timestamp{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *dst = Timestamp{Status: Present, InfinityModifier: -Infinity} + default: + tim, err := time.Parse(pgTimestampFormat, sbuf) + if err != nil { + return err + } + + *dst = Timestamp{Time: tim, Status: Present} + } + + return nil +} + +// DecodeBinary decodes from src into dst. The decoded time is considered to +// be in UTC. +func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Timestamp{Status: Null} + return nil + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for timestamp: %v", len(src)) + } + + microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) + + switch microsecSinceY2K { + case infinityMicrosecondOffset: + *dst = Timestamp{Status: Present, InfinityModifier: Infinity} + case negativeInfinityMicrosecondOffset: + *dst = Timestamp{Status: Present, InfinityModifier: -Infinity} + default: + microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K + tim := time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000).UTC() + *dst = Timestamp{Time: tim, Status: Present} + } + + return nil +} + +// EncodeText writes the text encoding of src into w. If src.Time is not in +// the UTC time zone it returns an error. +func (src Timestamp) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + if src.Time.Location() != time.UTC { + return nil, fmt.Errorf("cannot encode non-UTC time into timestamp") + } + + var s string + + switch src.InfinityModifier { + case None: + s = src.Time.Truncate(time.Microsecond).Format(pgTimestampFormat) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return append(buf, s...), nil +} + +// EncodeBinary writes the binary encoding of src into w. If src.Time is not in +// the UTC time zone it returns an error. +func (src Timestamp) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + if src.Time.Location() != time.UTC { + return nil, fmt.Errorf("cannot encode non-UTC time into timestamp") + } + + var microsecSinceY2K int64 + switch src.InfinityModifier { + case None: + microsecSinceUnixEpoch := src.Time.Unix()*1000000 + int64(src.Time.Nanosecond())/1000 + microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + case Infinity: + microsecSinceY2K = infinityMicrosecondOffset + case NegativeInfinity: + microsecSinceY2K = negativeInfinityMicrosecondOffset + } + + return pgio.AppendInt64(buf, microsecSinceY2K), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Timestamp) Scan(src interface{}) error { + if src == nil { + *dst = Timestamp{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + case time.Time: + *dst = Timestamp{Time: src, Status: Present} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Timestamp) Value() (driver.Value, error) { + switch src.Status { + case Present: + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/vendor/github.com/jackc/pgtype/timestamp_array.go b/vendor/github.com/jackc/pgtype/timestamp_array.go new file mode 100644 index 000000000..e12481e3d --- /dev/null +++ b/vendor/github.com/jackc/pgtype/timestamp_array.go @@ -0,0 +1,518 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + "time" + + "github.com/jackc/pgio" +) + +type TimestampArray struct { + Elements []Timestamp + Dimensions []ArrayDimension + Status Status +} + +func (dst *TimestampArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = TimestampArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []time.Time: + if value == nil { + *dst = TimestampArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestampArray{Status: Present} + } else { + elements := make([]Timestamp, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TimestampArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*time.Time: + if value == nil { + *dst = TimestampArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestampArray{Status: Present} + } else { + elements := make([]Timestamp, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TimestampArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Timestamp: + if value == nil { + *dst = TimestampArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestampArray{Status: Present} + } else { + *dst = TimestampArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = TimestampArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for TimestampArray", src) + } + if elementsLength == 0 { + *dst = TimestampArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to TimestampArray", src) + } + + *dst = TimestampArray{ + Elements: make([]Timestamp, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Timestamp, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to TimestampArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *TimestampArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to TimestampArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in TimestampArray", err) + } + index++ + + return index, nil +} + +func (dst TimestampArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *TimestampArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]time.Time: + *v = make([]time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*time.Time: + *v = make([]*time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *TimestampArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from TimestampArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from TimestampArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TimestampArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Timestamp + + if len(uta.Elements) > 0 { + elements = make([]Timestamp, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Timestamp + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TimestampArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *TimestampArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TimestampArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TimestampArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Timestamp, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = TimestampArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src TimestampArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src TimestampArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("timestamp"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "timestamp") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *TimestampArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src TimestampArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/timestamptz.go b/vendor/github.com/jackc/pgtype/timestamptz.go new file mode 100644 index 000000000..e0743060b --- /dev/null +++ b/vendor/github.com/jackc/pgtype/timestamptz.go @@ -0,0 +1,294 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "encoding/json" + "fmt" + "time" + + "github.com/jackc/pgio" +) + +const pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" +const pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00" +const pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00" +const microsecFromUnixEpochToY2K = 946684800 * 1000000 + +const ( + negativeInfinityMicrosecondOffset = -9223372036854775808 + infinityMicrosecondOffset = 9223372036854775807 +) + +type Timestamptz struct { + Time time.Time + Status Status + InfinityModifier InfinityModifier +} + +func (dst *Timestamptz) Set(src interface{}) error { + if src == nil { + *dst = Timestamptz{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case time.Time: + *dst = Timestamptz{Time: value, Status: Present} + case *time.Time: + if value == nil { + *dst = Timestamptz{Status: Null} + } else { + return dst.Set(*value) + } + case InfinityModifier: + *dst = Timestamptz{InfinityModifier: value, Status: Present} + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Timestamptz", value) + } + + return nil +} + +func (dst Timestamptz) Get() interface{} { + switch dst.Status { + case Present: + if dst.InfinityModifier != None { + return dst.InfinityModifier + } + return dst.Time + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Timestamptz) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *time.Time: + if src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return fmt.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Timestamptz{Status: Null} + return nil + } + + sbuf := string(src) + switch sbuf { + case "infinity": + *dst = Timestamptz{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *dst = Timestamptz{Status: Present, InfinityModifier: -Infinity} + default: + var format string + if len(sbuf) >= 9 && (sbuf[len(sbuf)-9] == '-' || sbuf[len(sbuf)-9] == '+') { + format = pgTimestamptzSecondFormat + } else if len(sbuf) >= 6 && (sbuf[len(sbuf)-6] == '-' || sbuf[len(sbuf)-6] == '+') { + format = pgTimestamptzMinuteFormat + } else { + format = pgTimestamptzHourFormat + } + + tim, err := time.Parse(format, sbuf) + if err != nil { + return err + } + + *dst = Timestamptz{Time: tim, Status: Present} + } + + return nil +} + +func (dst *Timestamptz) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Timestamptz{Status: Null} + return nil + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for timestamptz: %v", len(src)) + } + + microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) + + switch microsecSinceY2K { + case infinityMicrosecondOffset: + *dst = Timestamptz{Status: Present, InfinityModifier: Infinity} + case negativeInfinityMicrosecondOffset: + *dst = Timestamptz{Status: Present, InfinityModifier: -Infinity} + default: + microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K + tim := time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) + *dst = Timestamptz{Time: tim, Status: Present} + } + + return nil +} + +func (src Timestamptz) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var s string + + switch src.InfinityModifier { + case None: + s = src.Time.UTC().Truncate(time.Microsecond).Format(pgTimestamptzSecondFormat) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return append(buf, s...), nil +} + +func (src Timestamptz) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var microsecSinceY2K int64 + switch src.InfinityModifier { + case None: + microsecSinceUnixEpoch := src.Time.Unix()*1000000 + int64(src.Time.Nanosecond())/1000 + microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + case Infinity: + microsecSinceY2K = infinityMicrosecondOffset + case NegativeInfinity: + microsecSinceY2K = negativeInfinityMicrosecondOffset + } + + return pgio.AppendInt64(buf, microsecSinceY2K), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Timestamptz) Scan(src interface{}) error { + if src == nil { + *dst = Timestamptz{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + case time.Time: + *dst = Timestamptz{Time: src, Status: Present} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Timestamptz) Value() (driver.Value, error) { + switch src.Status { + case Present: + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} + +func (src Timestamptz) MarshalJSON() ([]byte, error) { + switch src.Status { + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + + if src.Status != Present { + return nil, errBadStatus + } + + var s string + + switch src.InfinityModifier { + case None: + s = src.Time.Format(time.RFC3339Nano) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return json.Marshal(s) +} + +func (dst *Timestamptz) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *dst = Timestamptz{Status: Null} + return nil + } + + switch *s { + case "infinity": + *dst = Timestamptz{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *dst = Timestamptz{Status: Present, InfinityModifier: -Infinity} + default: + // PostgreSQL uses ISO 8601 for to_json function and casting from a string to timestamptz + tim, err := time.Parse(time.RFC3339Nano, *s) + if err != nil { + return err + } + + *dst = Timestamptz{Time: tim, Status: Present} + } + + return nil +} diff --git a/vendor/github.com/jackc/pgtype/timestamptz_array.go b/vendor/github.com/jackc/pgtype/timestamptz_array.go new file mode 100644 index 000000000..a3b4b263d --- /dev/null +++ b/vendor/github.com/jackc/pgtype/timestamptz_array.go @@ -0,0 +1,518 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + "time" + + "github.com/jackc/pgio" +) + +type TimestamptzArray struct { + Elements []Timestamptz + Dimensions []ArrayDimension + Status Status +} + +func (dst *TimestamptzArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = TimestamptzArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []time.Time: + if value == nil { + *dst = TimestamptzArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestamptzArray{Status: Present} + } else { + elements := make([]Timestamptz, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TimestamptzArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*time.Time: + if value == nil { + *dst = TimestamptzArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestamptzArray{Status: Present} + } else { + elements := make([]Timestamptz, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = TimestamptzArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Timestamptz: + if value == nil { + *dst = TimestamptzArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestamptzArray{Status: Present} + } else { + *dst = TimestamptzArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = TimestamptzArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for TimestamptzArray", src) + } + if elementsLength == 0 { + *dst = TimestamptzArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to TimestamptzArray", src) + } + + *dst = TimestamptzArray{ + Elements: make([]Timestamptz, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Timestamptz, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to TimestamptzArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *TimestamptzArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to TimestamptzArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in TimestamptzArray", err) + } + index++ + + return index, nil +} + +func (dst TimestamptzArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *TimestamptzArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]time.Time: + *v = make([]time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*time.Time: + *v = make([]*time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *TimestamptzArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from TimestamptzArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from TimestamptzArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TimestamptzArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Timestamptz + + if len(uta.Elements) > 0 { + elements = make([]Timestamptz, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Timestamptz + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TimestamptzArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *TimestamptzArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TimestamptzArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TimestamptzArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Timestamptz, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = TimestamptzArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src TimestamptzArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src TimestamptzArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("timestamptz"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "timestamptz") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *TimestamptzArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src TimestamptzArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/tsrange.go b/vendor/github.com/jackc/pgtype/tsrange.go new file mode 100644 index 000000000..19ecf446a --- /dev/null +++ b/vendor/github.com/jackc/pgtype/tsrange.go @@ -0,0 +1,267 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + + "github.com/jackc/pgio" +) + +type Tsrange struct { + Lower Timestamp + Upper Timestamp + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Tsrange) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Tsrange{Status: Null} + return nil + } + + switch value := src.(type) { + case Tsrange: + *dst = value + case *Tsrange: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + default: + return fmt.Errorf("cannot convert %v to Tsrange", src) + } + + return nil +} + +func (dst Tsrange) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Tsrange) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Tsrange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Tsrange{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Tsrange{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Tsrange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Tsrange{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Tsrange{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Tsrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + var err error + + if src.LowerType != Unbounded { + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if src.UpperType != Unbounded { + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return buf, nil +} + +func (src Tsrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + buf = append(buf, rangeType) + + var err error + + if src.LowerType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if src.UpperType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Tsrange) Scan(src interface{}) error { + if src == nil { + *dst = Tsrange{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Tsrange) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/vendor/github.com/jackc/pgtype/tsrange_array.go b/vendor/github.com/jackc/pgtype/tsrange_array.go new file mode 100644 index 000000000..c64048eb0 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/tsrange_array.go @@ -0,0 +1,470 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type TsrangeArray struct { + Elements []Tsrange + Dimensions []ArrayDimension + Status Status +} + +func (dst *TsrangeArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = TsrangeArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []Tsrange: + if value == nil { + *dst = TsrangeArray{Status: Null} + } else if len(value) == 0 { + *dst = TsrangeArray{Status: Present} + } else { + *dst = TsrangeArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = TsrangeArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for TsrangeArray", src) + } + if elementsLength == 0 { + *dst = TsrangeArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to TsrangeArray", src) + } + + *dst = TsrangeArray{ + Elements: make([]Tsrange, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Tsrange, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to TsrangeArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *TsrangeArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to TsrangeArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in TsrangeArray", err) + } + index++ + + return index, nil +} + +func (dst TsrangeArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *TsrangeArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]Tsrange: + *v = make([]Tsrange, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *TsrangeArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from TsrangeArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from TsrangeArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *TsrangeArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TsrangeArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Tsrange + + if len(uta.Elements) > 0 { + elements = make([]Tsrange, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Tsrange + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TsrangeArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *TsrangeArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TsrangeArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TsrangeArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Tsrange, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = TsrangeArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src TsrangeArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src TsrangeArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("tsrange"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "tsrange") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *TsrangeArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src TsrangeArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/tstzrange.go b/vendor/github.com/jackc/pgtype/tstzrange.go new file mode 100644 index 000000000..255763081 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/tstzrange.go @@ -0,0 +1,267 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + + "github.com/jackc/pgio" +) + +type Tstzrange struct { + Lower Timestamptz + Upper Timestamptz + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Tstzrange) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = Tstzrange{Status: Null} + return nil + } + + switch value := src.(type) { + case Tstzrange: + *dst = value + case *Tstzrange: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + default: + return fmt.Errorf("cannot convert %v to Tstzrange", src) + } + + return nil +} + +func (dst Tstzrange) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Tstzrange) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Tstzrange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Tstzrange{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Tstzrange{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Tstzrange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Tstzrange{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Tstzrange{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Tstzrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + var err error + + if src.LowerType != Unbounded { + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if src.UpperType != Unbounded { + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return buf, nil +} + +func (src Tstzrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + buf = append(buf, rangeType) + + var err error + + if src.LowerType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if src.UpperType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Tstzrange) Scan(src interface{}) error { + if src == nil { + *dst = Tstzrange{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Tstzrange) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/vendor/github.com/jackc/pgtype/tstzrange_array.go b/vendor/github.com/jackc/pgtype/tstzrange_array.go new file mode 100644 index 000000000..a216820a3 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/tstzrange_array.go @@ -0,0 +1,470 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type TstzrangeArray struct { + Elements []Tstzrange + Dimensions []ArrayDimension + Status Status +} + +func (dst *TstzrangeArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = TstzrangeArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []Tstzrange: + if value == nil { + *dst = TstzrangeArray{Status: Null} + } else if len(value) == 0 { + *dst = TstzrangeArray{Status: Present} + } else { + *dst = TstzrangeArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = TstzrangeArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for TstzrangeArray", src) + } + if elementsLength == 0 { + *dst = TstzrangeArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to TstzrangeArray", src) + } + + *dst = TstzrangeArray{ + Elements: make([]Tstzrange, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Tstzrange, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to TstzrangeArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *TstzrangeArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to TstzrangeArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in TstzrangeArray", err) + } + index++ + + return index, nil +} + +func (dst TstzrangeArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *TstzrangeArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]Tstzrange: + *v = make([]Tstzrange, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *TstzrangeArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from TstzrangeArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from TstzrangeArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *TstzrangeArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TstzrangeArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Tstzrange + + if len(uta.Elements) > 0 { + elements = make([]Tstzrange, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Tstzrange + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TstzrangeArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *TstzrangeArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = TstzrangeArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TstzrangeArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Tstzrange, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = TstzrangeArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src TstzrangeArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src TstzrangeArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("tstzrange"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "tstzrange") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *TstzrangeArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src TstzrangeArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/typed_array.go.erb b/vendor/github.com/jackc/pgtype/typed_array.go.erb new file mode 100644 index 000000000..5788626b4 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/typed_array.go.erb @@ -0,0 +1,494 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgio" +) + +type <%= pgtype_array_type %> struct { + Elements []<%= pgtype_element_type %> + Dimensions []ArrayDimension + Status Status +} + +func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = <%= pgtype_array_type %>{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + <% go_array_types.split(",").each do |t| %> + <% if t != "[]#{pgtype_element_type}" %> + case <%= t %>: + if value == nil { + *dst = <%= pgtype_array_type %>{Status: Null} + } else if len(value) == 0 { + *dst = <%= pgtype_array_type %>{Status: Present} + } else { + elements := make([]<%= pgtype_element_type %>, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = <%= pgtype_array_type %>{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + <% end %> + <% end %> + case []<%= pgtype_element_type %>: + if value == nil { + *dst = <%= pgtype_array_type %>{Status: Null} + } else if len(value) == 0 { + *dst = <%= pgtype_array_type %>{Status: Present} + } else { + *dst = <%= pgtype_array_type %>{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status : Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = <%= pgtype_array_type %>{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for <%= pgtype_array_type %>", src) + } + if elementsLength == 0 { + *dst = <%= pgtype_array_type %>{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to <%= pgtype_array_type %>", src) + } + + *dst = <%= pgtype_array_type %> { + Elements: make([]<%= pgtype_element_type %>, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]<%= pgtype_element_type %>, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to <%= pgtype_array_type %>, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *<%= pgtype_array_type %>) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to <%= pgtype_array_type %>") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in <%= pgtype_array_type %>", err) + } + index++ + + return index, nil +} + +func (dst <%= pgtype_array_type %>) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1{ + // Attempt to match to select common types: + switch v := dst.(type) { + <% go_array_types.split(",").each do |t| %> + case *<%= t %>: + *v = make(<%= t %>, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + <% end %> + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *<%= pgtype_array_type %>) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr(){ + return 0, fmt.Errorf("cannot assign all values from <%= pgtype_array_type %>") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from <%= pgtype_array_type %>") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = <%= pgtype_array_type %>{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []<%= pgtype_element_type %> + + if len(uta.Elements) > 0 { + elements = make([]<%= pgtype_element_type %>, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem <%= pgtype_element_type %> + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = <%= pgtype_array_type %>{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +<% if binary_format == "true" %> +func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = <%= pgtype_array_type %>{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = <%= pgtype_array_type %>{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]<%= pgtype_element_type %>, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp:rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = <%= pgtype_array_type %>{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} +<% end %> + +func (src <%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `<%= text_null %>`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +<% if binary_format == "true" %> + func (src <%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("<%= element_type_name %>"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil + } +<% end %> + +// Scan implements the database/sql Scanner interface. +func (dst *<%= pgtype_array_type %>) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src <%= pgtype_array_type %>) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/typed_array_gen.sh b/vendor/github.com/jackc/pgtype/typed_array_gen.sh new file mode 100644 index 000000000..ea28be077 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/typed_array_gen.sh @@ -0,0 +1,28 @@ +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool,[]*bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go +erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time,[]*time.Time element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time,[]*time.Time element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go +erb pgtype_array_type=TstzrangeArray pgtype_element_type=Tstzrange go_array_types=[]Tstzrange element_type_name=tstzrange text_null=NULL binary_format=true typed_array.go.erb > tstzrange_array.go +erb pgtype_array_type=TsrangeArray pgtype_element_type=Tsrange go_array_types=[]Tsrange element_type_name=tsrange text_null=NULL binary_format=true typed_array.go.erb > tsrange_array.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time,[]*time.Time element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32,[]*float32 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64,[]*float64 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go +erb pgtype_array_type=MacaddrArray pgtype_element_type=Macaddr go_array_types=[]net.HardwareAddr,[]*net.HardwareAddr element_type_name=macaddr text_null=NULL binary_format=true typed_array.go.erb > macaddr_array.go +erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go +erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string,[]*string element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > text_array.go +erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string,[]*string element_type_name=varchar text_null=NULL binary_format=true typed_array.go.erb > varchar_array.go +erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar go_array_types=[]string,[]*string element_type_name=bpchar text_null=NULL binary_format=true typed_array.go.erb > bpchar_array.go +erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go +erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string,[]*string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go +erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go +erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]*float32,[]float64,[]*float64,[]int64,[]*int64,[]uint64,[]*uint64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go +erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string,[]*string element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go +erb pgtype_array_type=JSONBArray pgtype_element_type=JSONB go_array_types=[]string,[][]byte element_type_name=jsonb text_null=NULL binary_format=true typed_array.go.erb > jsonb_array.go + +# While the binary format is theoretically possible it is only practical to use the text format. +erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string,[]*string text_null=NULL binary_format=false typed_array.go.erb > enum_array.go + +goimports -w *_array.go diff --git a/vendor/github.com/jackc/pgtype/typed_range.go.erb b/vendor/github.com/jackc/pgtype/typed_range.go.erb new file mode 100644 index 000000000..5625587ae --- /dev/null +++ b/vendor/github.com/jackc/pgtype/typed_range.go.erb @@ -0,0 +1,269 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "fmt" + "io" + + "github.com/jackc/pgio" +) + +type <%= range_type %> struct { + Lower <%= element_type %> + Upper <%= element_type %> + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *<%= range_type %>) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = <%= range_type %>{Status: Null} + return nil + } + + switch value := src.(type) { + case <%= range_type %>: + *dst = value + case *<%= range_type %>: + *dst = *value + case string: + return dst.DecodeText(nil, []byte(value)) + default: + return fmt.Errorf("cannot convert %v to <%= range_type %>", src) + } + + return nil +} + +func (dst <%= range_type %>) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *<%= range_type %>) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *<%= range_type %>) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = <%= range_type %>{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = <%= range_type %>{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *<%= range_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = <%= range_type %>{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = <%= range_type %>{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src <%= range_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + buf = append(buf, '(') + case Inclusive: + buf = append(buf, '[') + case Empty: + return append(buf, "empty"...), nil + default: + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + var err error + + if src.LowerType != Unbounded { + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + buf = append(buf, ',') + + if src.UpperType != Unbounded { + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + buf = append(buf, ')') + case Inclusive: + buf = append(buf, ']') + default: + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return buf, nil +} + +func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + return append(buf, emptyMask), nil + default: + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + buf = append(buf, rangeType) + + var err error + + if src.LowerType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + if src.UpperType != Unbounded { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *<%= range_type %>) Scan(src interface{}) error { + if src == nil { + *dst = <%= range_type %>{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src <%= range_type %>) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/vendor/github.com/jackc/pgtype/typed_range_gen.sh b/vendor/github.com/jackc/pgtype/typed_range_gen.sh new file mode 100644 index 000000000..bedda2925 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/typed_range_gen.sh @@ -0,0 +1,7 @@ +erb range_type=Int4range element_type=Int4 typed_range.go.erb > int4range.go +erb range_type=Int8range element_type=Int8 typed_range.go.erb > int8range.go +erb range_type=Tsrange element_type=Timestamp typed_range.go.erb > tsrange.go +erb range_type=Tstzrange element_type=Timestamptz typed_range.go.erb > tstzrange.go +erb range_type=Daterange element_type=Date typed_range.go.erb > daterange.go +erb range_type=Numrange element_type=Numeric typed_range.go.erb > numrange.go +goimports -w *range.go diff --git a/vendor/github.com/jackc/pgtype/unknown.go b/vendor/github.com/jackc/pgtype/unknown.go new file mode 100644 index 000000000..c591b7083 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/unknown.go @@ -0,0 +1,44 @@ +package pgtype + +import "database/sql/driver" + +// Unknown represents the PostgreSQL unknown type. It is either a string literal +// or NULL. It is used when PostgreSQL does not know the type of a value. In +// general, this will only be used in pgx when selecting a null value without +// type information. e.g. SELECT NULL; +type Unknown struct { + String string + Status Status +} + +func (dst *Unknown) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst Unknown) Get() interface{} { + return (Text)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as Unknown is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *Unknown) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *Unknown) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *Unknown) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Unknown) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Unknown) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/vendor/github.com/jackc/pgtype/uuid.go b/vendor/github.com/jackc/pgtype/uuid.go new file mode 100644 index 000000000..fa0be07fe --- /dev/null +++ b/vendor/github.com/jackc/pgtype/uuid.go @@ -0,0 +1,230 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/hex" + "fmt" +) + +type UUID struct { + Bytes [16]byte + Status Status +} + +func (dst *UUID) Set(src interface{}) error { + if src == nil { + *dst = UUID{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case [16]byte: + *dst = UUID{Bytes: value, Status: Present} + case []byte: + if value != nil { + if len(value) != 16 { + return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) + } + *dst = UUID{Status: Present} + copy(dst.Bytes[:], value) + } else { + *dst = UUID{Status: Null} + } + case string: + uuid, err := parseUUID(value) + if err != nil { + return err + } + *dst = UUID{Bytes: uuid, Status: Present} + case *string: + if value == nil { + *dst = UUID{Status: Null} + } else { + return dst.Set(*value) + } + default: + if originalSrc, ok := underlyingUUIDType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to UUID", value) + } + + return nil +} + +func (dst UUID) Get() interface{} { + switch dst.Status { + case Present: + return dst.Bytes + case Null: + return nil + default: + return dst.Status + } +} + +func (src *UUID) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *[16]byte: + *v = src.Bytes + return nil + case *[]byte: + *v = make([]byte, 16) + copy(*v, src.Bytes[:]) + return nil + case *string: + *v = encodeUUID(src.Bytes) + return nil + default: + if nextDst, retry := GetAssignToDstType(v); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot assign %v into %T", src, dst) +} + +// parseUUID converts a string UUID in standard form to a byte array. +func parseUUID(src string) (dst [16]byte, err error) { + switch len(src) { + case 36: + src = src[0:8] + src[9:13] + src[14:18] + src[19:23] + src[24:] + case 32: + // dashes already stripped, assume valid + default: + // assume invalid. + return dst, fmt.Errorf("cannot parse UUID %v", src) + } + + buf, err := hex.DecodeString(src) + if err != nil { + return dst, err + } + + copy(dst[:], buf) + return dst, err +} + +// encodeUUID converts a uuid byte array to UUID standard string form. +func encodeUUID(src [16]byte) string { + return fmt.Sprintf("%x-%x-%x-%x-%x", src[0:4], src[4:6], src[6:8], src[8:10], src[10:16]) +} + +func (dst *UUID) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = UUID{Status: Null} + return nil + } + + if len(src) != 36 { + return fmt.Errorf("invalid length for UUID: %v", len(src)) + } + + buf, err := parseUUID(string(src)) + if err != nil { + return err + } + + *dst = UUID{Bytes: buf, Status: Present} + return nil +} + +func (dst *UUID) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = UUID{Status: Null} + return nil + } + + if len(src) != 16 { + return fmt.Errorf("invalid length for UUID: %v", len(src)) + } + + *dst = UUID{Status: Present} + copy(dst.Bytes[:], src) + return nil +} + +func (src UUID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, encodeUUID(src.Bytes)...), nil +} + +func (src UUID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, src.Bytes[:]...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *UUID) Scan(src interface{}) error { + if src == nil { + *dst = UUID{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src UUID) Value() (driver.Value, error) { + return EncodeValueText(src) +} + +func (src UUID) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + var buff bytes.Buffer + buff.WriteByte('"') + buff.WriteString(encodeUUID(src.Bytes)) + buff.WriteByte('"') + return buff.Bytes(), nil + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + return nil, errBadStatus +} + +func (dst *UUID) UnmarshalJSON(src []byte) error { + if bytes.Compare(src, []byte("null")) == 0 { + return dst.Set(nil) + } + if len(src) != 38 { + return fmt.Errorf("invalid length for UUID: %v", len(src)) + } + return dst.Set(string(src[1 : len(src)-1])) +} diff --git a/vendor/github.com/jackc/pgtype/uuid_array.go b/vendor/github.com/jackc/pgtype/uuid_array.go new file mode 100644 index 000000000..00721ef93 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/uuid_array.go @@ -0,0 +1,573 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type UUIDArray struct { + Elements []UUID + Dimensions []ArrayDimension + Status Status +} + +func (dst *UUIDArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = UUIDArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case [][16]byte: + if value == nil { + *dst = UUIDArray{Status: Null} + } else if len(value) == 0 { + *dst = UUIDArray{Status: Present} + } else { + elements := make([]UUID, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = UUIDArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case [][]byte: + if value == nil { + *dst = UUIDArray{Status: Null} + } else if len(value) == 0 { + *dst = UUIDArray{Status: Present} + } else { + elements := make([]UUID, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = UUIDArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []string: + if value == nil { + *dst = UUIDArray{Status: Null} + } else if len(value) == 0 { + *dst = UUIDArray{Status: Present} + } else { + elements := make([]UUID, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = UUIDArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*string: + if value == nil { + *dst = UUIDArray{Status: Null} + } else if len(value) == 0 { + *dst = UUIDArray{Status: Present} + } else { + elements := make([]UUID, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = UUIDArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []UUID: + if value == nil { + *dst = UUIDArray{Status: Null} + } else if len(value) == 0 { + *dst = UUIDArray{Status: Present} + } else { + *dst = UUIDArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = UUIDArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for UUIDArray", src) + } + if elementsLength == 0 { + *dst = UUIDArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to UUIDArray", src) + } + + *dst = UUIDArray{ + Elements: make([]UUID, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]UUID, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to UUIDArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *UUIDArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to UUIDArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in UUIDArray", err) + } + index++ + + return index, nil +} + +func (dst UUIDArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *UUIDArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[][16]byte: + *v = make([][16]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[][]byte: + *v = make([][]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *UUIDArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from UUIDArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from UUIDArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *UUIDArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = UUIDArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []UUID + + if len(uta.Elements) > 0 { + elements = make([]UUID, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem UUID + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = UUIDArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *UUIDArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = UUIDArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = UUIDArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]UUID, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = UUIDArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src UUIDArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src UUIDArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("uuid"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "uuid") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *UUIDArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src UUIDArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/varbit.go b/vendor/github.com/jackc/pgtype/varbit.go new file mode 100644 index 000000000..f24dc5bcf --- /dev/null +++ b/vendor/github.com/jackc/pgtype/varbit.go @@ -0,0 +1,133 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + + "github.com/jackc/pgio" +) + +type Varbit struct { + Bytes []byte + Len int32 // Number of bits + Status Status +} + +func (dst *Varbit) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Varbit", src) +} + +func (dst Varbit) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Varbit) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Varbit) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Varbit{Status: Null} + return nil + } + + bitLen := len(src) + byteLen := bitLen / 8 + if bitLen%8 > 0 { + byteLen++ + } + buf := make([]byte, byteLen) + + for i, b := range src { + if b == '1' { + byteIdx := i / 8 + bitIdx := uint(i % 8) + buf[byteIdx] = buf[byteIdx] | (128 >> bitIdx) + } + } + + *dst = Varbit{Bytes: buf, Len: int32(bitLen), Status: Present} + return nil +} + +func (dst *Varbit) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Varbit{Status: Null} + return nil + } + + if len(src) < 4 { + return fmt.Errorf("invalid length for varbit: %v", len(src)) + } + + bitLen := int32(binary.BigEndian.Uint32(src)) + rp := 4 + + *dst = Varbit{Bytes: src[rp:], Len: bitLen, Status: Present} + return nil +} + +func (src Varbit) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + for i := int32(0); i < src.Len; i++ { + byteIdx := i / 8 + bitMask := byte(128 >> byte(i%8)) + char := byte('0') + if src.Bytes[byteIdx]&bitMask > 0 { + char = '1' + } + buf = append(buf, char) + } + + return buf, nil +} + +func (src Varbit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + buf = pgio.AppendInt32(buf, src.Len) + return append(buf, src.Bytes...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Varbit) Scan(src interface{}) error { + if src == nil { + *dst = Varbit{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Varbit) Value() (driver.Value, error) { + return EncodeValueText(src) +} diff --git a/vendor/github.com/jackc/pgtype/varchar.go b/vendor/github.com/jackc/pgtype/varchar.go new file mode 100644 index 000000000..fea31d181 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/varchar.go @@ -0,0 +1,66 @@ +package pgtype + +import ( + "database/sql/driver" +) + +type Varchar Text + +// Set converts from src to dst. Note that as Varchar is not a general +// number type Set does not do automatic type conversion as other number +// types do. +func (dst *Varchar) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst Varchar) Get() interface{} { + return (Text)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as Varchar is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *Varchar) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (Varchar) PreferredResultFormat() int16 { + return TextFormatCode +} + +func (dst *Varchar) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *Varchar) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} + +func (Varchar) PreferredParamFormat() int16 { + return TextFormatCode +} + +func (src Varchar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Text)(src).EncodeText(ci, buf) +} + +func (src Varchar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (Text)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Varchar) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Varchar) Value() (driver.Value, error) { + return (Text)(src).Value() +} + +func (src Varchar) MarshalJSON() ([]byte, error) { + return (Text)(src).MarshalJSON() +} + +func (dst *Varchar) UnmarshalJSON(b []byte) error { + return (*Text)(dst).UnmarshalJSON(b) +} diff --git a/vendor/github.com/jackc/pgtype/varchar_array.go b/vendor/github.com/jackc/pgtype/varchar_array.go new file mode 100644 index 000000000..8a309a3f8 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/varchar_array.go @@ -0,0 +1,517 @@ +// Code generated by erb. DO NOT EDIT. + +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "reflect" + + "github.com/jackc/pgio" +) + +type VarcharArray struct { + Elements []Varchar + Dimensions []ArrayDimension + Status Status +} + +func (dst *VarcharArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = VarcharArray{Status: Null} + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + // Attempt to match to select common types: + switch value := src.(type) { + + case []string: + if value == nil { + *dst = VarcharArray{Status: Null} + } else if len(value) == 0 { + *dst = VarcharArray{Status: Present} + } else { + elements := make([]Varchar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = VarcharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []*string: + if value == nil { + *dst = VarcharArray{Status: Null} + } else if len(value) == 0 { + *dst = VarcharArray{Status: Present} + } else { + elements := make([]Varchar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = VarcharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []Varchar: + if value == nil { + *dst = VarcharArray{Status: Null} + } else if len(value) == 0 { + *dst = VarcharArray{Status: Present} + } else { + *dst = VarcharArray{ + Elements: value, + Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}}, + Status: Present, + } + } + default: + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + reflectedValue := reflect.ValueOf(src) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + *dst = VarcharArray{Status: Null} + return nil + } + + dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0) + if !ok { + return fmt.Errorf("cannot find dimensions of %v for VarcharArray", src) + } + if elementsLength == 0 { + *dst = VarcharArray{Status: Present} + return nil + } + if len(dimensions) == 0 { + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to VarcharArray", src) + } + + *dst = VarcharArray{ + Elements: make([]Varchar, elementsLength), + Dimensions: dimensions, + Status: Present, + } + elementCount, err := dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + // Maybe the target was one dimension too far, try again: + if len(dst.Dimensions) > 1 { + dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1] + elementsLength = 0 + for _, dim := range dst.Dimensions { + if elementsLength == 0 { + elementsLength = int(dim.Length) + } else { + elementsLength *= int(dim.Length) + } + } + dst.Elements = make([]Varchar, elementsLength) + elementCount, err = dst.setRecursive(reflectedValue, 0, 0) + if err != nil { + return err + } + } else { + return err + } + } + if elementCount != len(dst.Elements) { + return fmt.Errorf("cannot convert %v to VarcharArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount) + } + } + + return nil +} + +func (dst *VarcharArray) setRecursive(value reflect.Value, index, dimension int) (int, error) { + switch value.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(dst.Dimensions) == dimension { + break + } + + valueLen := value.Len() + if int32(valueLen) != dst.Dimensions[dimension].Length { + return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions") + } + for i := 0; i < valueLen; i++ { + var err error + index, err = dst.setRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if !value.CanInterface() { + return 0, fmt.Errorf("cannot convert all values to VarcharArray") + } + if err := dst.Elements[index].Set(value.Interface()); err != nil { + return 0, fmt.Errorf("%v in VarcharArray", err) + } + index++ + + return index, nil +} + +func (dst VarcharArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *VarcharArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + if len(src.Dimensions) <= 1 { + // Attempt to match to select common types: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]*string: + *v = make([]*string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + } + } + + // Try to convert to something AssignTo can use directly. + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + + // Fallback to reflection if an optimised match was not found. + // The reflection is necessary for arrays and multidimensional slices, + // but it comes with a 20-50% performance penalty for large arrays/slices + value := reflect.ValueOf(dst) + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + switch value.Kind() { + case reflect.Array, reflect.Slice: + default: + return fmt.Errorf("cannot assign %T to %T", src, dst) + } + + if len(src.Elements) == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + return nil + } + } + + elementCount, err := src.assignToRecursive(value, 0, 0) + if err != nil { + return err + } + if elementCount != len(src.Elements) { + return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount) + } + + return nil + case Null: + return NullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %#v into %T", src, dst) +} + +func (src *VarcharArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) { + switch kind := value.Kind(); kind { + case reflect.Array: + fallthrough + case reflect.Slice: + if len(src.Dimensions) == dimension { + break + } + + length := int(src.Dimensions[dimension].Length) + if reflect.Array == kind { + typ := value.Type() + if typ.Len() != length { + return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len()) + } + value.Set(reflect.New(typ).Elem()) + } else { + value.Set(reflect.MakeSlice(value.Type(), length, length)) + } + + var err error + for i := 0; i < length; i++ { + index, err = src.assignToRecursive(value.Index(i), index, dimension+1) + if err != nil { + return 0, err + } + } + + return index, nil + } + if len(src.Dimensions) != dimension { + return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension) + } + if !value.CanAddr() { + return 0, fmt.Errorf("cannot assign all values from VarcharArray") + } + addr := value.Addr() + if !addr.CanInterface() { + return 0, fmt.Errorf("cannot assign all values from VarcharArray") + } + if err := src.Elements[index].AssignTo(addr.Interface()); err != nil { + return 0, err + } + index++ + return index, nil +} + +func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = VarcharArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Varchar + + if len(uta.Elements) > 0 { + elements = make([]Varchar, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Varchar + var elemSrc []byte + if s != "NULL" || uta.Quoted[i] { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = VarcharArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *VarcharArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = VarcharArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = VarcharArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Varchar, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = VarcharArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src VarcharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src VarcharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("varchar"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, fmt.Errorf("unable to find oid for type name %v", "varchar") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *VarcharArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src VarcharArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/vendor/github.com/jackc/pgtype/xid.go b/vendor/github.com/jackc/pgtype/xid.go new file mode 100644 index 000000000..f6d6b22d5 --- /dev/null +++ b/vendor/github.com/jackc/pgtype/xid.go @@ -0,0 +1,64 @@ +package pgtype + +import ( + "database/sql/driver" +) + +// XID is PostgreSQL's Transaction ID type. +// +// In later versions of PostgreSQL, it is the type used for the backend_xid +// and backend_xmin columns of the pg_stat_activity system view. +// +// Also, when one does +// +// select xmin, xmax, * from some_table; +// +// it is the data type of the xmin and xmax hidden system columns. +// +// It is currently implemented as an unsigned four byte integer. +// Its definition can be found in src/include/postgres_ext.h as TransactionId +// in the PostgreSQL sources. +type XID pguint32 + +// Set converts from src to dst. Note that as XID is not a general +// number type Set does not do automatic type conversion as other number +// types do. +func (dst *XID) Set(src interface{}) error { + return (*pguint32)(dst).Set(src) +} + +func (dst XID) Get() interface{} { + return (pguint32)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as XID is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *XID) AssignTo(dst interface{}) error { + return (*pguint32)(src).AssignTo(dst) +} + +func (dst *XID) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) +} + +func (dst *XID) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) +} + +func (src XID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (pguint32)(src).EncodeText(ci, buf) +} + +func (src XID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (pguint32)(src).EncodeBinary(ci, buf) +} + +// Scan implements the database/sql Scanner interface. +func (dst *XID) Scan(src interface{}) error { + return (*pguint32)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src XID) Value() (driver.Value, error) { + return (pguint32)(src).Value() +} diff --git a/vendor/github.com/jackc/pgx/v4/.gitignore b/vendor/github.com/jackc/pgx/v4/.gitignore new file mode 100644 index 000000000..39175a965 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/.gitignore @@ -0,0 +1,24 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe + +.envrc diff --git a/vendor/github.com/jackc/pgx/v4/CHANGELOG.md b/vendor/github.com/jackc/pgx/v4/CHANGELOG.md new file mode 100644 index 000000000..ef4a2029a --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/CHANGELOG.md @@ -0,0 +1,209 @@ +# 4.13.0 (July 24, 2021) + +* Trimmed pseudo-dependencies in Go modules from other packages tests +* Upgrade pgconn -- context cancellation no longer will return a net.Error +* Support time durations for simple protocol (Michael Darr) + +# 4.12.0 (July 10, 2021) + +* ResetSession hook is called before a connection is reused from pool for another query (Dmytro Haranzha) +* stdlib: Add RandomizeHostOrderFunc (dkinder) +* stdlib: add OptionBeforeConnect (dkinder) +* stdlib: Do not reuse ConnConfig strings (Andrew Kimball) +* stdlib: implement Conn.ResetSession (Jonathan Amsterdam) +* Upgrade pgconn to v1.9.0 +* Upgrade pgtype to v1.8.0 + +# 4.11.0 (March 25, 2021) + +* Add BeforeConnect callback to pgxpool.Config (Robert Froehlich) +* Add Ping method to pgxpool.Conn (davidsbond) +* Added a kitlog level log adapter (Fabrice Aneche) +* Make ScanArgError public to allow identification of offending column (Pau Sanchez) +* Add *pgxpool.AcquireFunc +* Add BeginFunc and BeginTxFunc +* Add prefer_simple_protocol to connection string +* Add logging on CopyFrom (Patrick Hemmer) +* Add comment support when sanitizing SQL queries (Rusakow Andrew) +* Do not panic on double close of pgxpool.Pool (Matt Schultz) +* Avoid panic on SendBatch on closed Tx (Matt Schultz) +* Update pgconn to v1.8.1 +* Update pgtype to v1.7.0 + +# 4.10.1 (December 19, 2020) + +* Fix panic on Query error with nil stmtcache. + +# 4.10.0 (December 3, 2020) + +* Add CopyFromSlice to simplify CopyFrom usage (Egon Elbre) +* Remove broken prepared statements from stmtcache (Ethan Pailes) +* stdlib: consider any Ping error as fatal +* Update puddle to v1.1.3 - this fixes an issue where concurrent Acquires can hang when a connection cannot be established +* Update pgtype to v1.6.2 + +# 4.9.2 (November 3, 2020) + +The underlying library updates fix an issue where appending to a scanned slice could corrupt other data. + +* Update pgconn to v1.7.2 +* Update pgproto3 to v2.0.6 + +# 4.9.1 (October 31, 2020) + +* Update pgconn to v1.7.1 +* Update pgtype to v1.6.1 +* Fix SendBatch of all prepared statements with statement cache disabled + +# 4.9.0 (September 26, 2020) + +* pgxpool now waits for connection cleanup to finish before making room in pool for another connection. This prevents temporarily exceeding max pool size. +* Fix when scanning a column to nil to skip it on the first row but scanning it to a real value on a subsequent row. +* Fix prefer simple protocol with prepared statements. (Jinzhu) +* Fix FieldDescriptions not being available on Rows before calling Next the first time. +* Various minor fixes in updated versions of pgconn, pgtype, and puddle. + +# 4.8.1 (July 29, 2020) + +* Update pgconn to v1.6.4 + * Fix deadlock on error after CommandComplete but before ReadyForQuery + * Fix panic on parsing DSN with trailing '=' + +# 4.8.0 (July 22, 2020) + +* All argument types supported by native pgx should now also work through database/sql +* Update pgconn to v1.6.3 +* Update pgtype to v1.4.2 + +# 4.7.2 (July 14, 2020) + +* Improve performance of Columns() (zikaeroh) +* Fix fatal Commit() failure not being considered fatal +* Update pgconn to v1.6.2 +* Update pgtype to v1.4.1 + +# 4.7.1 (June 29, 2020) + +* Fix stdlib decoding error with certain order and combination of fields + +# 4.7.0 (June 27, 2020) + +* Update pgtype to v1.4.0 +* Update pgconn to v1.6.1 +* Update puddle to v1.1.1 +* Fix context propagation with Tx commit and Rollback (georgysavva) +* Add lazy connect option to pgxpool (georgysavva) +* Fix connection leak if pgxpool.BeginTx() fail (Jean-Baptiste Bronisz) +* Add native Go slice support for strings and numbers to simple protocol +* stdlib add default timeouts for Conn.Close() and Stmt.Close() (georgysavva) +* Assorted performance improvements especially with large result sets +* Fix close pool on not lazy connect failure (Yegor Myskin) +* Add Config copy (georgysavva) +* Support SendBatch with Simple Protocol (Jordan Lewis) +* Better error logging on rows close (Igor V. Kozinov) +* Expose stdlib.Conn.Conn() to enable database/sql.Conn.Raw() +* Improve unknown type support for database/sql +* Fix transaction commit failure closing connection + +# 4.6.0 (March 30, 2020) + +* stdlib: Bail early if preloading rows.Next() results in rows.Err() (Bas van Beek) +* Sanitize time to microsecond accuracy (Andrew Nicoll) +* Update pgtype to v1.3.0 +* Update pgconn to v1.5.0 + * Update golang.org/x/crypto for security fix + * Implement "verify-ca" SSL mode + +# 4.5.0 (March 7, 2020) + +* Update to pgconn v1.4.0 + * Fixes QueryRow with empty SQL + * Adds PostgreSQL service file support +* Add Len() to *pgx.Batch (WGH) +* Better logging for individual batch items (Ben Bader) + +# 4.4.1 (February 14, 2020) + +* Update pgconn to v1.3.2 - better default read buffer size +* Fix race in CopyFrom + +# 4.4.0 (February 5, 2020) + +* Update puddle to v1.1.0 - fixes possible deadlock when acquire is cancelled +* Update pgconn to v1.3.1 - fixes CopyFrom deadlock when multiple NoticeResponse received during copy +* Update pgtype to v1.2.0 +* Add MaxConnIdleTime to pgxpool (Patrick Ellul) +* Add MinConns to pgxpool (Patrick Ellul) +* Fix: stdlib.ReleaseConn closes connections left in invalid state + +# 4.3.0 (January 23, 2020) + +* Fix Rows.Values panic when unable to decode +* Add Rows.Values support for unknown types +* Add DriverContext support for stdlib (Alex Gaynor) +* Update pgproto3 to v2.0.1 to never return an io.EOF as it would be misinterpreted by database/sql. Instead return io.UnexpectedEOF. + +# 4.2.1 (January 13, 2020) + +* Update pgconn to v1.2.1 (fixes context cancellation data race introduced in v1.2.0)) + +# 4.2.0 (January 11, 2020) + +* Update pgconn to v1.2.0. +* Update pgtype to v1.1.0. +* Return error instead of panic when wrong number of arguments passed to Exec. (malstoun) +* Fix large objects functionality when PreferSimpleProtocol = true. +* Restore GetDefaultDriver which existed in v3. (Johan Brandhorst) +* Add RegisterConnConfig to stdlib which replaces the removed RegisterDriverConfig from v3. + +# 4.1.2 (October 22, 2019) + +* Fix dbSavepoint.Begin recursive self call +* Upgrade pgtype to v1.0.2 - fix scan pointer to pointer + +# 4.1.1 (October 21, 2019) + +* Fix pgxpool Rows.CommandTag() infinite loop / typo + +# 4.1.0 (October 12, 2019) + +## Potentially Breaking Changes + +Technically, two changes are breaking changes, but in practice these are extremely unlikely to break existing code. + +* Conn.Begin and Conn.BeginTx return a Tx interface instead of the internal dbTx struct. This is necessary for the Conn.Begin method to signature as other methods that begin a transaction. +* Add Conn() to Tx interface. This is necessary to allow code using a Tx to access the *Conn (and pgconn.PgConn) on which the Tx is executing. + +## Fixes + +* Releasing a busy connection closes the connection instead of returning an unusable connection to the pool +* Do not mutate config.Config.OnNotification in connect + +# 4.0.1 (September 19, 2019) + +* Fix statement cache cleanup. +* Corrected daterange OID. +* Fix Tx when committing or rolling back multiple times in certain cases. +* Improve documentation. + +# 4.0.0 (September 14, 2019) + +v4 is a major release with many significant changes some of which are breaking changes. The most significant are +included below. + +* Simplified establishing a connection with a connection string. +* All potentially blocking operations now require a context.Context. The non-context aware functions have been removed. +* OIDs are hard-coded for known types. This saves the query on connection. +* Context cancellations while network activity is in progress is now always fatal. Previously, it was sometimes recoverable. This led to increased complexity in pgx itself and in application code. +* Go modules are required. +* Errors are now implemented in the Go 1.13 style. +* `Rows` and `Tx` are now interfaces. +* The connection pool as been decoupled from pgx and is now a separate, included package (github.com/jackc/pgx/v4/pgxpool). +* pgtype has been spun off to a separate package (github.com/jackc/pgtype). +* pgproto3 has been spun off to a separate package (github.com/jackc/pgproto3/v2). +* Logical replication support has been spun off to a separate package (github.com/jackc/pglogrepl). +* Lower level PostgreSQL functionality is now implemented in a separate package (github.com/jackc/pgconn). +* Tests are now configured with environment variables. +* Conn has an automatic statement cache by default. +* Batch interface has been simplified. +* QueryArgs has been removed. diff --git a/vendor/github.com/jackc/pgx/v4/LICENSE b/vendor/github.com/jackc/pgx/v4/LICENSE new file mode 100644 index 000000000..5c486c39a --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2013-2021 Jack Christensen + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/jackc/pgx/v4/README.md b/vendor/github.com/jackc/pgx/v4/README.md new file mode 100644 index 000000000..732320447 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/README.md @@ -0,0 +1,203 @@ +[](https://pkg.go.dev/github.com/jackc/pgx/v4) +[](https://travis-ci.org/jackc/pgx) + +# pgx - PostgreSQL Driver and Toolkit + +pgx is a pure Go driver and toolkit for PostgreSQL. + +pgx aims to be low-level, fast, and performant, while also enabling PostgreSQL-specific features that the standard `database/sql` package does not allow for. + +The driver component of pgx can be used alongside the standard `database/sql` package. + +The toolkit component is a related set of packages that implement PostgreSQL functionality such as parsing the wire protocol +and type mapping between PostgreSQL and Go. These underlying packages can be used to implement alternative drivers, +proxies, load balancers, logical replication clients, etc. + +The current release of `pgx v4` requires Go modules. To use the previous version, checkout and vendor the `v3` branch. + +## Example Usage + +```go +package main + +import ( + "context" + "fmt" + "os" + + "github.com/jackc/pgx/v4" +) + +func main() { + // urlExample := "postgres://username:password@localhost:5432/database_name" + conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) + if err != nil { + fmt.Fprintf(os.Stderr, "Unable to connect to database: %v\n", err) + os.Exit(1) + } + defer conn.Close(context.Background()) + + var name string + var weight int64 + err = conn.QueryRow(context.Background(), "select name, weight from widgets where id=$1", 42).Scan(&name, &weight) + if err != nil { + fmt.Fprintf(os.Stderr, "QueryRow failed: %v\n", err) + os.Exit(1) + } + + fmt.Println(name, weight) +} +``` + +See the [getting started guide](https://github.com/jackc/pgx/wiki/Getting-started-with-pgx) for more information. + +## Choosing Between the pgx and database/sql Interfaces + +It is recommended to use the pgx interface if: +1. The application only targets PostgreSQL. +2. No other libraries that require `database/sql` are in use. + +The pgx interface is faster and exposes more features. + +The `database/sql` interface only allows the underlying driver to return or receive the following types: `int64`, +`float64`, `bool`, `[]byte`, `string`, `time.Time`, or `nil`. Handling other types requires implementing the +`database/sql.Scanner` and the `database/sql/driver/driver.Valuer` interfaces which require transmission of values in text format. The binary format can be substantially faster, which is what the pgx interface uses. + +## Features + +pgx supports many features beyond what is available through `database/sql`: + +* Support for approximately 70 different PostgreSQL types +* Automatic statement preparation and caching +* Batch queries +* Single-round trip query mode +* Full TLS connection control +* Binary format support for custom types (allows for much quicker encoding/decoding) +* Copy protocol support for faster bulk data loads +* Extendable logging support including built-in support for `log15adapter`, [`logrus`](https://github.com/sirupsen/logrus), [`zap`](https://github.com/uber-go/zap), and [`zerolog`](https://github.com/rs/zerolog) +* Connection pool with after-connect hook for arbitrary connection setup +* Listen / notify +* Conversion of PostgreSQL arrays to Go slice mappings for integers, floats, and strings +* Hstore support +* JSON and JSONB support +* Maps `inet` and `cidr` PostgreSQL types to `net.IPNet` and `net.IP` +* Large object support +* NULL mapping to Null* struct or pointer to pointer +* Supports `database/sql.Scanner` and `database/sql/driver.Valuer` interfaces for custom types +* Notice response handling +* Simulated nested transactions with savepoints + +## Performance + +There are three areas in particular where pgx can provide a significant performance advantage over the standard +`database/sql` interface and other drivers: + +1. PostgreSQL specific types - Types such as arrays can be parsed much quicker because pgx uses the binary format. +2. Automatic statement preparation and caching - pgx will prepare and cache statements by default. This can provide an + significant free improvement to code that does not explicitly use prepared statements. Under certain workloads, it can + perform nearly 3x the number of queries per second. +3. Batched queries - Multiple queries can be batched together to minimize network round trips. + +## Comparison with Alternatives + +* [pq](http://godoc.org/github.com/lib/pq) +* [go-pg](https://github.com/go-pg/pg) + +For prepared queries with small sets of simple data types, all drivers will have have similar performance. However, if prepared statements aren't being explicitly used, pgx can have a significant performance advantage due to automatic statement preparation. +pgx also can perform better when using PostgreSQL-specific data types or query batching. See +[go_db_bench](https://github.com/jackc/go_db_bench) for some database driver benchmarks. + +### Compatibility with `database/sql` + +pq is exclusively used with `database/sql`. go-pg does not use `database/sql` at all. pgx supports `database/sql` as well as +its own interface. + +### Level of access, ORM + +go-pg is a PostgreSQL client and ORM. It includes many features that traditionally sit above the database driver, such as ORM, struct mapping, soft deletes, schema migrations, and sharding support. + +pgx is "closer to the metal" and such abstractions are beyond the scope of the pgx project, which first and foremost, aims to be a performant driver and toolkit. + +## Testing + +pgx tests naturally require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_DATABASE` environment +variable. The `PGX_TEST_DATABASE` environment variable can either be a URL or DSN. In addition, the standard `PG*` environment +variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify environment variable +handling. + +### Example Test Environment + +Connect to your PostgreSQL server and run: + +``` +create database pgx_test; +``` + +Connect to the newly-created database and run: + +``` +create domain uint64 as numeric(20,0); +``` + +Now, you can run the tests: + +``` +PGX_TEST_DATABASE="host=/var/run/postgresql database=pgx_test" go test ./... +``` + +In addition, there are tests specific for PgBouncer that will be executed if `PGX_TEST_PGBOUNCER_CONN_STRING` is set. + +## Supported Go and PostgreSQL Versions + +pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.15 and higher and PostgreSQL 9.6 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). + +## Version Policy + +pgx follows semantic versioning for the documented public API on stable releases. `v4` is the latest stable major version. + +## PGX Family Libraries + +pgx is the head of a family of PostgreSQL libraries. Many of these can be used independently. Many can also be accessed +from pgx for lower-level control. + +### [github.com/jackc/pgconn](https://github.com/jackc/pgconn) + +`pgconn` is a lower-level PostgreSQL database driver that operates at nearly the same level as the C library `libpq`. + +### [github.com/jackc/pgx/v4/pgxpool](https://github.com/jackc/pgx/tree/master/pgxpool) + +`pgxpool` is a connection pool for pgx. pgx is entirely decoupled from its default pool implementation. This means that pgx can be used with a different pool or without any pool at all. + +### [github.com/jackc/pgx/v4/stdlib](https://github.com/jackc/pgx/tree/master/stdlib) + +This is a `database/sql` compatibility layer for pgx. pgx can be used as a normal `database/sql` driver, but at any time, the native interface can be acquired for more performance or PostgreSQL specific functionality. + +### [github.com/jackc/pgtype](https://github.com/jackc/pgtype) + +Over 70 PostgreSQL types are supported including `uuid`, `hstore`, `json`, `bytea`, `numeric`, `interval`, `inet`, and arrays. These types support `database/sql` interfaces and are usable outside of pgx. They are fully tested in pgx and pq. They also support a higher performance interface when used with the pgx driver. + +### [github.com/jackc/pgproto3](https://github.com/jackc/pgproto3) + +pgproto3 provides standalone encoding and decoding of the PostgreSQL v3 wire protocol. This is useful for implementing very low level PostgreSQL tooling. + +### [github.com/jackc/pglogrepl](https://github.com/jackc/pglogrepl) + +pglogrepl provides functionality to act as a client for PostgreSQL logical replication. + +### [github.com/jackc/pgmock](https://github.com/jackc/pgmock) + +pgmock offers the ability to create a server that mocks the PostgreSQL wire protocol. This is used internally to test pgx by purposely inducing unusual errors. pgproto3 and pgmock together provide most of the foundational tooling required to implement a PostgreSQL proxy or MitM (such as for a custom connection pooler). + +### [github.com/jackc/tern](https://github.com/jackc/tern) + +tern is a stand-alone SQL migration system. + +### [github.com/jackc/pgerrcode](https://github.com/jackc/pgerrcode) + +pgerrcode contains constants for the PostgreSQL error codes. + +## 3rd Party Libraries with PGX Support + +### [github.com/georgysavva/scany](https://github.com/georgysavva/scany) + +Library for scanning data from a database into Go structs and more. diff --git a/vendor/github.com/jackc/pgx/v4/batch.go b/vendor/github.com/jackc/pgx/v4/batch.go new file mode 100644 index 000000000..4b96ca194 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/batch.go @@ -0,0 +1,179 @@ +package pgx + +import ( + "context" + "errors" + + "github.com/jackc/pgconn" +) + +type batchItem struct { + query string + arguments []interface{} +} + +// Batch queries are a way of bundling multiple queries together to avoid +// unnecessary network round trips. +type Batch struct { + items []*batchItem +} + +// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. +func (b *Batch) Queue(query string, arguments ...interface{}) { + b.items = append(b.items, &batchItem{ + query: query, + arguments: arguments, + }) +} + +// Len returns number of queries that have been queued so far. +func (b *Batch) Len() int { + return len(b.items) +} + +type BatchResults interface { + // Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. + Exec() (pgconn.CommandTag, error) + + // Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. + Query() (Rows, error) + + // QueryRow reads the results from the next query in the batch as if the query has been sent with Conn.QueryRow. + QueryRow() Row + + // Close closes the batch operation. This must be called before the underlying connection can be used again. Any error + // that occurred during a batch operation may have made it impossible to resyncronize the connection with the server. + // In this case the underlying connection will have been closed. + Close() error +} + +type batchResults struct { + ctx context.Context + conn *Conn + mrr *pgconn.MultiResultReader + err error + b *Batch + ix int +} + +// Exec reads the results from the next query in the batch as if the query has been sent with Exec. +func (br *batchResults) Exec() (pgconn.CommandTag, error) { + if br.err != nil { + return nil, br.err + } + + query, arguments, _ := br.nextQueryAndArgs() + + if !br.mrr.NextResult() { + err := br.mrr.Close() + if err == nil { + err = errors.New("no result") + } + if br.conn.shouldLog(LogLevelError) { + br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]interface{}{ + "sql": query, + "args": logQueryArgs(arguments), + "err": err, + }) + } + return nil, err + } + + commandTag, err := br.mrr.ResultReader().Close() + + if err != nil { + if br.conn.shouldLog(LogLevelError) { + br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]interface{}{ + "sql": query, + "args": logQueryArgs(arguments), + "err": err, + }) + } + } else if br.conn.shouldLog(LogLevelInfo) { + br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]interface{}{ + "sql": query, + "args": logQueryArgs(arguments), + "commandTag": commandTag, + }) + } + + return commandTag, err +} + +// Query reads the results from the next query in the batch as if the query has been sent with Query. +func (br *batchResults) Query() (Rows, error) { + query, arguments, ok := br.nextQueryAndArgs() + if !ok { + query = "batch query" + } + + if br.err != nil { + return &connRows{err: br.err, closed: true}, br.err + } + + rows := br.conn.getRows(br.ctx, query, arguments) + + if !br.mrr.NextResult() { + rows.err = br.mrr.Close() + if rows.err == nil { + rows.err = errors.New("no result") + } + rows.closed = true + + if br.conn.shouldLog(LogLevelError) { + br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]interface{}{ + "sql": query, + "args": logQueryArgs(arguments), + "err": rows.err, + }) + } + + return rows, rows.err + } + + rows.resultReader = br.mrr.ResultReader() + return rows, nil +} + +// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow. +func (br *batchResults) QueryRow() Row { + rows, _ := br.Query() + return (*connRow)(rows.(*connRows)) + +} + +// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to +// resyncronize the connection with the server. In this case the underlying connection will have been closed. +func (br *batchResults) Close() error { + if br.err != nil { + return br.err + } + + // log any queries that haven't yet been logged by Exec or Query + for { + query, args, ok := br.nextQueryAndArgs() + if !ok { + break + } + + if br.conn.shouldLog(LogLevelInfo) { + br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]interface{}{ + "sql": query, + "args": logQueryArgs(args), + }) + } + } + + return br.mrr.Close() +} + +func (br *batchResults) nextQueryAndArgs() (query string, args []interface{}, ok bool) { + if br.b != nil && br.ix < len(br.b.items) { + bi := br.b.items[br.ix] + query = bi.query + args = bi.arguments + ok = true + br.ix++ + } + return +} diff --git a/vendor/github.com/jackc/pgx/v4/conn.go b/vendor/github.com/jackc/pgx/v4/conn.go new file mode 100644 index 000000000..9636f2fd6 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/conn.go @@ -0,0 +1,850 @@ +package pgx + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/jackc/pgconn" + "github.com/jackc/pgconn/stmtcache" + "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4/internal/sanitize" +) + +// ConnConfig contains all the options used to establish a connection. It must be created by ParseConfig and +// then it can be modified. A manually initialized ConnConfig will cause ConnectConfig to panic. +type ConnConfig struct { + pgconn.Config + Logger Logger + LogLevel LogLevel + + // Original connection string that was parsed into config. + connString string + + // BuildStatementCache creates the stmtcache.Cache implementation for connections created with this config. Set + // to nil to disable automatic prepared statements. + BuildStatementCache BuildStatementCacheFunc + + // PreferSimpleProtocol disables implicit prepared statement usage. By default pgx automatically uses the extended + // protocol. This can improve performance due to being able to use the binary format. It also does not rely on client + // side parameter sanitization. However, it does incur two round-trips per query (unless using a prepared statement) + // and may be incompatible proxies such as PGBouncer. Setting PreferSimpleProtocol causes the simple protocol to be + // used by default. The same functionality can be controlled on a per query basis by setting + // QueryExOptions.SimpleProtocol. + PreferSimpleProtocol bool + + createdByParseConfig bool // Used to enforce created by ParseConfig rule. +} + +// Copy returns a deep copy of the config that is safe to use and modify. +// The only exception is the tls.Config: +// according to the tls.Config docs it must not be modified after creation. +func (cc *ConnConfig) Copy() *ConnConfig { + newConfig := new(ConnConfig) + *newConfig = *cc + newConfig.Config = *newConfig.Config.Copy() + return newConfig +} + +func (cc *ConnConfig) ConnString() string { return cc.connString } + +// BuildStatementCacheFunc is a function that can be used to create a stmtcache.Cache implementation for connection. +type BuildStatementCacheFunc func(conn *pgconn.PgConn) stmtcache.Cache + +// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. Use a connection pool to manage access +// to multiple database connections from multiple goroutines. +type Conn struct { + pgConn *pgconn.PgConn + config *ConnConfig // config used when establishing this connection + preparedStatements map[string]*pgconn.StatementDescription + stmtcache stmtcache.Cache + logger Logger + logLevel LogLevel + + notifications []*pgconn.Notification + + doneChan chan struct{} + closedChan chan error + + connInfo *pgtype.ConnInfo + + wbuf []byte + preallocatedRows []connRows + eqb extendedQueryBuilder +} + +// Identifier a PostgreSQL identifier or name. Identifiers can be composed of +// multiple parts such as ["schema", "table"] or ["table", "column"]. +type Identifier []string + +// Sanitize returns a sanitized string safe for SQL interpolation. +func (ident Identifier) Sanitize() string { + parts := make([]string, len(ident)) + for i := range ident { + s := strings.ReplaceAll(ident[i], string([]byte{0}), "") + parts[i] = `"` + strings.ReplaceAll(s, `"`, `""`) + `"` + } + return strings.Join(parts, ".") +} + +// ErrNoRows occurs when rows are expected but none are returned. +var ErrNoRows = errors.New("no rows in result set") + +// ErrInvalidLogLevel occurs on attempt to set an invalid log level. +var ErrInvalidLogLevel = errors.New("invalid log level") + +// Connect establishes a connection with a PostgreSQL server with a connection string. See +// pgconn.Connect for details. +func Connect(ctx context.Context, connString string) (*Conn, error) { + connConfig, err := ParseConfig(connString) + if err != nil { + return nil, err + } + return connect(ctx, connConfig) +} + +// Connect establishes a connection with a PostgreSQL server with a configuration struct. connConfig must have been +// created by ParseConfig. +func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) { + return connect(ctx, connConfig) +} + +// ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that pgconn.ParseConfig +// does. In addition, it accepts the following options: +// +// statement_cache_capacity +// The maximum size of the automatic statement cache. Set to 0 to disable automatic statement caching. Default: 512. +// +// statement_cache_mode +// Possible values: "prepare" and "describe". "prepare" will create prepared statements on the PostgreSQL server. +// "describe" will use the anonymous prepared statement to describe a statement without creating a statement on the +// server. "describe" is primarily useful when the environment does not allow prepared statements such as when +// running a connection pooler like PgBouncer. Default: "prepare" +// +// prefer_simple_protocol +// Possible values: "true" and "false". Use the simple protocol instead of extended protocol. Default: false +func ParseConfig(connString string) (*ConnConfig, error) { + config, err := pgconn.ParseConfig(connString) + if err != nil { + return nil, err + } + + var buildStatementCache BuildStatementCacheFunc + statementCacheCapacity := 512 + statementCacheMode := stmtcache.ModePrepare + if s, ok := config.RuntimeParams["statement_cache_capacity"]; ok { + delete(config.RuntimeParams, "statement_cache_capacity") + n, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return nil, fmt.Errorf("cannot parse statement_cache_capacity: %w", err) + } + statementCacheCapacity = int(n) + } + + if s, ok := config.RuntimeParams["statement_cache_mode"]; ok { + delete(config.RuntimeParams, "statement_cache_mode") + switch s { + case "prepare": + statementCacheMode = stmtcache.ModePrepare + case "describe": + statementCacheMode = stmtcache.ModeDescribe + default: + return nil, fmt.Errorf("invalid statement_cache_mod: %s", s) + } + } + + if statementCacheCapacity > 0 { + buildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { + return stmtcache.New(conn, statementCacheMode, statementCacheCapacity) + } + } + + preferSimpleProtocol := false + if s, ok := config.RuntimeParams["prefer_simple_protocol"]; ok { + delete(config.RuntimeParams, "prefer_simple_protocol") + if b, err := strconv.ParseBool(s); err == nil { + preferSimpleProtocol = b + } else { + return nil, fmt.Errorf("invalid prefer_simple_protocol: %v", err) + } + } + + connConfig := &ConnConfig{ + Config: *config, + createdByParseConfig: true, + LogLevel: LogLevelInfo, + BuildStatementCache: buildStatementCache, + PreferSimpleProtocol: preferSimpleProtocol, + connString: connString, + } + + return connConfig, nil +} + +func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { + // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from + // zero values. + if !config.createdByParseConfig { + panic("config must be created by ParseConfig") + } + originalConfig := config + + // This isn't really a deep copy. But it is enough to avoid the config.Config.OnNotification mutation from affecting + // other connections with the same config. See https://github.com/jackc/pgx/issues/618. + { + configCopy := *config + config = &configCopy + } + + c = &Conn{ + config: originalConfig, + connInfo: pgtype.NewConnInfo(), + logLevel: config.LogLevel, + logger: config.Logger, + } + + // Only install pgx notification system if no other callback handler is present. + if config.Config.OnNotification == nil { + config.Config.OnNotification = c.bufferNotifications + } else { + if c.shouldLog(LogLevelDebug) { + c.log(ctx, LogLevelDebug, "pgx notification handler disabled by application supplied OnNotification", map[string]interface{}{"host": config.Config.Host}) + } + } + + if c.shouldLog(LogLevelInfo) { + c.log(ctx, LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Config.Host}) + } + c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config) + if err != nil { + if c.shouldLog(LogLevelError) { + c.log(ctx, LogLevelError, "connect failed", map[string]interface{}{"err": err}) + } + return nil, err + } + + c.preparedStatements = make(map[string]*pgconn.StatementDescription) + c.doneChan = make(chan struct{}) + c.closedChan = make(chan error) + c.wbuf = make([]byte, 0, 1024) + + if c.config.BuildStatementCache != nil { + c.stmtcache = c.config.BuildStatementCache(c.pgConn) + } + + // Replication connections can't execute the queries to + // populate the c.PgTypes and c.pgsqlAfInet + if _, ok := config.Config.RuntimeParams["replication"]; ok { + return c, nil + } + + return c, nil +} + +// Close closes a connection. It is safe to call Close on a already closed +// connection. +func (c *Conn) Close(ctx context.Context) error { + if c.IsClosed() { + return nil + } + + err := c.pgConn.Close(ctx) + if c.shouldLog(LogLevelInfo) { + c.log(ctx, LogLevelInfo, "closed connection", nil) + } + return err +} + +// Prepare creates a prepared statement with name and sql. sql can contain placeholders +// for bound parameters. These placeholders are referenced positional as $1, $2, etc. +// +// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same +// name and sql arguments. This allows a code path to Prepare and Query/Exec without +// concern for if the statement has already been prepared. +func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) { + if name != "" { + var ok bool + if sd, ok = c.preparedStatements[name]; ok && sd.SQL == sql { + return sd, nil + } + } + + if c.shouldLog(LogLevelError) { + defer func() { + if err != nil { + c.log(ctx, LogLevelError, "Prepare failed", map[string]interface{}{"err": err, "name": name, "sql": sql}) + } + }() + } + + sd, err = c.pgConn.Prepare(ctx, name, sql, nil) + if err != nil { + return nil, err + } + + if name != "" { + c.preparedStatements[name] = sd + } + + return sd, nil +} + +// Deallocate released a prepared statement +func (c *Conn) Deallocate(ctx context.Context, name string) error { + delete(c.preparedStatements, name) + _, err := c.pgConn.Exec(ctx, "deallocate "+quoteIdentifier(name)).ReadAll() + return err +} + +func (c *Conn) bufferNotifications(_ *pgconn.PgConn, n *pgconn.Notification) { + c.notifications = append(c.notifications, n) +} + +// WaitForNotification waits for a PostgreSQL notification. It wraps the underlying pgconn notification system in a +// slightly more convenient form. +func (c *Conn) WaitForNotification(ctx context.Context) (*pgconn.Notification, error) { + var n *pgconn.Notification + + // Return already received notification immediately + if len(c.notifications) > 0 { + n = c.notifications[0] + c.notifications = c.notifications[1:] + return n, nil + } + + err := c.pgConn.WaitForNotification(ctx) + if len(c.notifications) > 0 { + n = c.notifications[0] + c.notifications = c.notifications[1:] + } + return n, err +} + +func (c *Conn) IsClosed() bool { + return c.pgConn.IsClosed() +} + +func (c *Conn) die(err error) { + if c.IsClosed() { + return + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // force immediate hard cancel + c.pgConn.Close(ctx) +} + +func (c *Conn) shouldLog(lvl LogLevel) bool { + return c.logger != nil && c.logLevel >= lvl +} + +func (c *Conn) log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{}) { + if data == nil { + data = map[string]interface{}{} + } + if c.pgConn != nil && c.pgConn.PID() != 0 { + data["pid"] = c.pgConn.PID() + } + + c.logger.Log(ctx, lvl, msg, data) +} + +func quoteIdentifier(s string) string { + return `"` + strings.ReplaceAll(s, `"`, `""`) + `"` +} + +func (c *Conn) Ping(ctx context.Context) error { + _, err := c.Exec(ctx, ";") + return err +} + +func connInfoFromRows(rows Rows, err error) (map[string]uint32, error) { + if err != nil { + return nil, err + } + defer rows.Close() + + nameOIDs := make(map[string]uint32, 256) + for rows.Next() { + var oid uint32 + var name pgtype.Text + if err = rows.Scan(&oid, &name); err != nil { + return nil, err + } + + nameOIDs[name.String] = oid + } + + if err = rows.Err(); err != nil { + return nil, err + } + + return nameOIDs, err +} + +// PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the +// PostgreSQL connection than pgx exposes. +// +// It is strongly recommended that the connection be idle (no in-progress queries) before the underlying *pgconn.PgConn +// is used and the connection must be returned to the same state before any *pgx.Conn methods are again used. +func (c *Conn) PgConn() *pgconn.PgConn { return c.pgConn } + +// StatementCache returns the statement cache used for this connection. +func (c *Conn) StatementCache() stmtcache.Cache { return c.stmtcache } + +// ConnInfo returns the connection info used for this connection. +func (c *Conn) ConnInfo() *pgtype.ConnInfo { return c.connInfo } + +// Config returns a copy of config that was used to establish this connection. +func (c *Conn) Config() *ConnConfig { return c.config.Copy() } + +// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced +// positionally from the sql string as $1, $2, etc. +func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { + startTime := time.Now() + + commandTag, err := c.exec(ctx, sql, arguments...) + if err != nil { + if c.shouldLog(LogLevelError) { + c.log(ctx, LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err}) + } + return commandTag, err + } + + if c.shouldLog(LogLevelInfo) { + endTime := time.Now() + c.log(ctx, LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag}) + } + + return commandTag, err +} + +func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { + simpleProtocol := c.config.PreferSimpleProtocol + +optionLoop: + for len(arguments) > 0 { + switch arg := arguments[0].(type) { + case QuerySimpleProtocol: + simpleProtocol = bool(arg) + arguments = arguments[1:] + default: + break optionLoop + } + } + + if sd, ok := c.preparedStatements[sql]; ok { + return c.execPrepared(ctx, sd, arguments) + } + + if simpleProtocol { + return c.execSimpleProtocol(ctx, sql, arguments) + } + + if len(arguments) == 0 { + return c.execSimpleProtocol(ctx, sql, arguments) + } + + if c.stmtcache != nil { + sd, err := c.stmtcache.Get(ctx, sql) + if err != nil { + return nil, err + } + + if c.stmtcache.Mode() == stmtcache.ModeDescribe { + return c.execParams(ctx, sd, arguments) + } + return c.execPrepared(ctx, sd, arguments) + } + + sd, err := c.Prepare(ctx, "", sql) + if err != nil { + return nil, err + } + return c.execPrepared(ctx, sd, arguments) +} + +func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []interface{}) (commandTag pgconn.CommandTag, err error) { + if len(arguments) > 0 { + sql, err = c.sanitizeForSimpleQuery(sql, arguments...) + if err != nil { + return nil, err + } + } + + mrr := c.pgConn.Exec(ctx, sql) + for mrr.NextResult() { + commandTag, err = mrr.ResultReader().Close() + } + err = mrr.Close() + return commandTag, err +} + +func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, arguments []interface{}) error { + if len(sd.ParamOIDs) != len(arguments) { + return fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(arguments)) + } + + c.eqb.Reset() + + args, err := convertDriverValuers(arguments) + if err != nil { + return err + } + + for i := range args { + err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i]) + if err != nil { + return err + } + } + + for i := range sd.Fields { + c.eqb.AppendResultFormat(c.ConnInfo().ResultFormatCodeForOID(sd.Fields[i].DataTypeOID)) + } + + return nil +} + +func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) { + err := c.execParamsAndPreparedPrefix(sd, arguments) + if err != nil { + return nil, err + } + + result := c.pgConn.ExecParams(ctx, sd.SQL, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read() + return result.CommandTag, result.Err +} + +func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) { + err := c.execParamsAndPreparedPrefix(sd, arguments) + if err != nil { + return nil, err + } + + result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read() + return result.CommandTag, result.Err +} + +func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows { + if len(c.preallocatedRows) == 0 { + c.preallocatedRows = make([]connRows, 64) + } + + r := &c.preallocatedRows[len(c.preallocatedRows)-1] + c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1] + + r.ctx = ctx + r.logger = c + r.connInfo = c.connInfo + r.startTime = time.Now() + r.sql = sql + r.args = args + r.conn = c + + return r +} + +// QuerySimpleProtocol controls whether the simple or extended protocol is used to send the query. +type QuerySimpleProtocol bool + +// QueryResultFormats controls the result format (text=0, binary=1) of a query by result column position. +type QueryResultFormats []int16 + +// QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID. +type QueryResultFormatsByOID map[uint32]int16 + +// Query executes sql with args. If there is an error the returned Rows will be returned in an error state. So it is +// allowed to ignore the error returned from Query and handle it in Rows. +// +// For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and +// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely +// needed. See the documentation for those types for details. +func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { + var resultFormats QueryResultFormats + var resultFormatsByOID QueryResultFormatsByOID + simpleProtocol := c.config.PreferSimpleProtocol + +optionLoop: + for len(args) > 0 { + switch arg := args[0].(type) { + case QueryResultFormats: + resultFormats = arg + args = args[1:] + case QueryResultFormatsByOID: + resultFormatsByOID = arg + args = args[1:] + case QuerySimpleProtocol: + simpleProtocol = bool(arg) + args = args[1:] + default: + break optionLoop + } + } + + rows := c.getRows(ctx, sql, args) + + var err error + sd, ok := c.preparedStatements[sql] + + if simpleProtocol && !ok { + sql, err = c.sanitizeForSimpleQuery(sql, args...) + if err != nil { + rows.fatal(err) + return rows, err + } + + mrr := c.pgConn.Exec(ctx, sql) + if mrr.NextResult() { + rows.resultReader = mrr.ResultReader() + rows.multiResultReader = mrr + } else { + err = mrr.Close() + rows.fatal(err) + return rows, err + } + + return rows, nil + } + + c.eqb.Reset() + + if !ok { + if c.stmtcache != nil { + sd, err = c.stmtcache.Get(ctx, sql) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + } else { + sd, err = c.pgConn.Prepare(ctx, "", sql, nil) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + } + } + if len(sd.ParamOIDs) != len(args) { + rows.fatal(fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args))) + return rows, rows.err + } + + rows.sql = sd.SQL + + args, err = convertDriverValuers(args) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + + for i := range args { + err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i]) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + } + + if resultFormatsByOID != nil { + resultFormats = make([]int16, len(sd.Fields)) + for i := range resultFormats { + resultFormats[i] = resultFormatsByOID[uint32(sd.Fields[i].DataTypeOID)] + } + } + + if resultFormats == nil { + for i := range sd.Fields { + c.eqb.AppendResultFormat(c.ConnInfo().ResultFormatCodeForOID(sd.Fields[i].DataTypeOID)) + } + + resultFormats = c.eqb.resultFormats + } + + if c.stmtcache != nil && c.stmtcache.Mode() == stmtcache.ModeDescribe { + rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, resultFormats) + } else { + rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats) + } + + return rows, rows.err +} + +// QueryRow is a convenience wrapper over Query. Any error that occurs while +// querying is deferred until calling Scan on the returned Row. That Row will +// error with ErrNoRows if no rows are returned. +func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { + rows, _ := c.Query(ctx, sql, args...) + return (*connRow)(rows.(*connRows)) +} + +// QueryFuncRow is the argument to the QueryFunc callback function. +// +// QueryFuncRow is an interface instead of a struct to allow tests to mock QueryFunc. However, adding a method to an +// interface is technically a breaking change. Because of this the QueryFuncRow interface is partially excluded from +// semantic version requirements. Methods will not be removed or changed, but new methods may be added. +type QueryFuncRow interface { + FieldDescriptions() []pgproto3.FieldDescription + + // RawValues returns the unparsed bytes of the row values. The returned [][]byte is only valid during the current + // function call. However, the underlying byte data is safe to retain a reference to and mutate. + RawValues() [][]byte +} + +// QueryFunc executes sql with args. For each row returned by the query the values will scanned into the elements of +// scans and f will be called. If any row fails to scan or f returns an error the query will be aborted and the error +// will be returned. +func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { + rows, err := c.Query(ctx, sql, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + err = rows.Scan(scans...) + if err != nil { + return nil, err + } + + err = f(rows) + if err != nil { + return nil, err + } + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return rows.CommandTag(), nil +} + +// SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless +// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection +// is used again. +func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { + simpleProtocol := c.config.PreferSimpleProtocol + var sb strings.Builder + if simpleProtocol { + for i, bi := range b.items { + if i > 0 { + sb.WriteByte(';') + } + sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + sb.WriteString(sql) + } + mrr := c.pgConn.Exec(ctx, sb.String()) + return &batchResults{ + ctx: ctx, + conn: c, + mrr: mrr, + b: b, + ix: 0, + } + } + + distinctUnpreparedQueries := map[string]struct{}{} + + for _, bi := range b.items { + if _, ok := c.preparedStatements[bi.query]; ok { + continue + } + distinctUnpreparedQueries[bi.query] = struct{}{} + } + + var stmtCache stmtcache.Cache + if len(distinctUnpreparedQueries) > 0 { + if c.stmtcache != nil && c.stmtcache.Cap() >= len(distinctUnpreparedQueries) { + stmtCache = c.stmtcache + } else { + stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries)) + } + + for sql, _ := range distinctUnpreparedQueries { + _, err := stmtCache.Get(ctx, sql) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + } + } + + batch := &pgconn.Batch{} + + for _, bi := range b.items { + c.eqb.Reset() + + sd := c.preparedStatements[bi.query] + if sd == nil { + var err error + sd, err = stmtCache.Get(ctx, bi.query) + if err != nil { + // the stmtCache was prefilled from distinctUnpreparedQueries above so we are guaranteed no errors + panic("BUG: unexpected error from stmtCache") + } + } + + if len(sd.ParamOIDs) != len(bi.arguments) { + return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")} + } + + args, err := convertDriverValuers(bi.arguments) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + + for i := range args { + err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i]) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + } + + for i := range sd.Fields { + c.eqb.AppendResultFormat(c.ConnInfo().ResultFormatCodeForOID(sd.Fields[i].DataTypeOID)) + } + + if sd.Name == "" { + batch.ExecParams(bi.query, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats) + } else { + batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats) + } + } + + mrr := c.pgConn.ExecBatch(ctx, batch) + + return &batchResults{ + ctx: ctx, + conn: c, + mrr: mrr, + b: b, + ix: 0, + } +} + +func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, error) { + if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" { + return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on") + } + + if c.pgConn.ParameterStatus("client_encoding") != "UTF8" { + return "", errors.New("simple protocol queries must be run with client_encoding=UTF8") + } + + var err error + valueArgs := make([]interface{}, len(args)) + for i, a := range args { + valueArgs[i], err = convertSimpleArgument(c.connInfo, a) + if err != nil { + return "", err + } + } + + return sanitize.SanitizeSQL(sql, valueArgs...) +} diff --git a/vendor/github.com/jackc/pgx/v4/copy_from.go b/vendor/github.com/jackc/pgx/v4/copy_from.go new file mode 100644 index 000000000..3494e28f9 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/copy_from.go @@ -0,0 +1,211 @@ +package pgx + +import ( + "bytes" + "context" + "fmt" + "io" + "time" + + "github.com/jackc/pgconn" + "github.com/jackc/pgio" +) + +// CopyFromRows returns a CopyFromSource interface over the provided rows slice +// making it usable by *Conn.CopyFrom. +func CopyFromRows(rows [][]interface{}) CopyFromSource { + return ©FromRows{rows: rows, idx: -1} +} + +type copyFromRows struct { + rows [][]interface{} + idx int +} + +func (ctr *copyFromRows) Next() bool { + ctr.idx++ + return ctr.idx < len(ctr.rows) +} + +func (ctr *copyFromRows) Values() ([]interface{}, error) { + return ctr.rows[ctr.idx], nil +} + +func (ctr *copyFromRows) Err() error { + return nil +} + +// CopyFromSlice returns a CopyFromSource interface over a dynamic func +// making it usable by *Conn.CopyFrom. +func CopyFromSlice(length int, next func(int) ([]interface{}, error)) CopyFromSource { + return ©FromSlice{next: next, idx: -1, len: length} +} + +type copyFromSlice struct { + next func(int) ([]interface{}, error) + idx int + len int + err error +} + +func (cts *copyFromSlice) Next() bool { + cts.idx++ + return cts.idx < cts.len +} + +func (cts *copyFromSlice) Values() ([]interface{}, error) { + values, err := cts.next(cts.idx) + if err != nil { + cts.err = err + } + return values, err +} + +func (cts *copyFromSlice) Err() error { + return cts.err +} + +// CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data. +type CopyFromSource interface { + // Next returns true if there is another row and makes the next row data + // available to Values(). When there are no more rows available or an error + // has occurred it returns false. + Next() bool + + // Values returns the values for the current row. + Values() ([]interface{}, error) + + // Err returns any error that has been encountered by the CopyFromSource. If + // this is not nil *Conn.CopyFrom will abort the copy. + Err() error +} + +type copyFrom struct { + conn *Conn + tableName Identifier + columnNames []string + rowSrc CopyFromSource + readerErrChan chan error +} + +func (ct *copyFrom) run(ctx context.Context) (int64, error) { + quotedTableName := ct.tableName.Sanitize() + cbuf := &bytes.Buffer{} + for i, cn := range ct.columnNames { + if i != 0 { + cbuf.WriteString(", ") + } + cbuf.WriteString(quoteIdentifier(cn)) + } + quotedColumnNames := cbuf.String() + + sd, err := ct.conn.Prepare(ctx, "", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName)) + if err != nil { + return 0, err + } + + r, w := io.Pipe() + doneChan := make(chan struct{}) + + go func() { + defer close(doneChan) + + // Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283. + buf := ct.conn.wbuf + + buf = append(buf, "PGCOPY\n\377\r\n\000"...) + buf = pgio.AppendInt32(buf, 0) + buf = pgio.AppendInt32(buf, 0) + + moreRows := true + for moreRows { + var err error + moreRows, buf, err = ct.buildCopyBuf(buf, sd) + if err != nil { + w.CloseWithError(err) + return + } + + if ct.rowSrc.Err() != nil { + w.CloseWithError(ct.rowSrc.Err()) + return + } + + if len(buf) > 0 { + _, err = w.Write(buf) + if err != nil { + w.Close() + return + } + } + + buf = buf[:0] + } + + w.Close() + }() + + startTime := time.Now() + + commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) + + r.Close() + <-doneChan + + rowsAffected := commandTag.RowsAffected() + if err == nil { + if ct.conn.shouldLog(LogLevelInfo) { + endTime := time.Now() + ct.conn.log(ctx, LogLevelInfo, "CopyFrom", map[string]interface{}{"tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime), "rowCount": rowsAffected}) + } + } else if ct.conn.shouldLog(LogLevelError) { + ct.conn.log(ctx, LogLevelError, "CopyFrom", map[string]interface{}{"err": err, "tableName": ct.tableName, "columnNames": ct.columnNames}) + } + + return rowsAffected, err +} + +func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) { + + for ct.rowSrc.Next() { + values, err := ct.rowSrc.Values() + if err != nil { + return false, nil, err + } + if len(values) != len(ct.columnNames) { + return false, nil, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) + } + + buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) + for i, val := range values { + buf, err = encodePreparedStatementArgument(ct.conn.connInfo, buf, sd.Fields[i].DataTypeOID, val) + if err != nil { + return false, nil, err + } + } + + if len(buf) > 65536 { + return true, buf, nil + } + } + + return false, buf, nil +} + +// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. +// It returns the number of rows copied and an error. +// +// CopyFrom requires all values use the binary format. Almost all types +// implemented by pgx use the binary format by default. Types implementing +// Encoder can only be used if they encode to the binary format. +func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { + ct := ©From{ + conn: c, + tableName: tableName, + columnNames: columnNames, + rowSrc: rowSrc, + readerErrChan: make(chan error), + } + + return ct.run(ctx) +} diff --git a/vendor/github.com/jackc/pgx/v4/doc.go b/vendor/github.com/jackc/pgx/v4/doc.go new file mode 100644 index 000000000..51b0d9f44 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/doc.go @@ -0,0 +1,340 @@ +// Package pgx is a PostgreSQL database driver. +/* +pgx provides lower level access to PostgreSQL than the standard database/sql. It remains as similar to the database/sql +interface as possible while providing better speed and access to PostgreSQL specific features. Import +github.com/jackc/pgx/v4/stdlib to use pgx as a database/sql compatible driver. + +Establishing a Connection + +The primary way of establishing a connection is with `pgx.Connect`. + + conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) + +The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified +here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the connection with +`ConnectConfig`. + + config, err := pgx.ParseConfig(os.Getenv("DATABASE_URL")) + if err != nil { + // ... + } + config.Logger = log15adapter.NewLogger(log.New("module", "pgx")) + + conn, err := pgx.ConnectConfig(context.Background(), config) + +Connection Pool + +`*pgx.Conn` represents a single connection to the database and is not concurrency safe. Use sub-package pgxpool for a +concurrency safe connection pool. + +Query Interface + +pgx implements Query and Scan in the familiar database/sql style. + + var sum int32 + + // Send the query to the server. The returned rows MUST be closed + // before conn can be used again. + rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10) + if err != nil { + return err + } + + // rows.Close is called by rows.Next when all rows are read + // or an error occurs in Next or Scan. So it may optionally be + // omitted if nothing in the rows.Next loop can panic. It is + // safe to close rows multiple times. + defer rows.Close() + + // Iterate through the result set + for rows.Next() { + var n int32 + err = rows.Scan(&n) + if err != nil { + return err + } + sum += n + } + + // Any errors encountered by rows.Next or rows.Scan will be returned here + if rows.Err() != nil { + return rows.Err() + } + + // No errors found - do something with sum + +pgx also implements QueryRow in the same style as database/sql. + + var name string + var weight int64 + err := conn.QueryRow(context.Background(), "select name, weight from widgets where id=$1", 42).Scan(&name, &weight) + if err != nil { + return err + } + +Use Exec to execute a query that does not return a result set. + + commandTag, err := conn.Exec(context.Background(), "delete from widgets where id=$1", 42) + if err != nil { + return err + } + if commandTag.RowsAffected() != 1 { + return errors.New("No row found to delete") + } + +QueryFunc can be used to execute a callback function for every row. This is often easier to use than Query. + + var sum, n int32 + _, err = conn.QueryFunc( + context.Background(), + "select generate_series(1,$1)", + []interface{}{10}, + []interface{}{&n}, + func(pgx.QueryFuncRow) error { + sum += n + return nil + }, + ) + if err != nil { + return err + } + +Base Type Mapping + +pgx maps between all common base types directly between Go and PostgreSQL. In particular: + + Go PostgreSQL + ----------------------- + string varchar + text + + // Integers are automatically be converted to any other integer type if + // it can be done without overflow or underflow. + int8 + int16 smallint + int32 int + int64 bigint + int + uint8 + uint16 + uint32 + uint64 + uint + + // Floats are strict and do not automatically convert like integers. + float32 float4 + float64 float8 + + time.Time date + timestamp + timestamptz + + []byte bytea + + +Null Mapping + +pgx can map nulls in two ways. The first is package pgtype provides types that have a data field and a status field. +They work in a similar fashion to database/sql. The second is to use a pointer to a pointer. + + var foo pgtype.Varchar + var bar *string + err := conn.QueryRow("select foo, bar from widgets where id=$1", 42).Scan(&foo, &bar) + if err != nil { + return err + } + +Array Mapping + +pgx maps between int16, int32, int64, float32, float64, and string Go slices and the equivalent PostgreSQL array type. +Go slices of native types do not support nulls, so if a PostgreSQL array that contains a null is read into a native Go +slice an error will occur. The pgtype package includes many more array types for PostgreSQL types that do not directly +map to native Go types. + +JSON and JSONB Mapping + +pgx includes built-in support to marshal and unmarshal between Go types and the PostgreSQL JSON and JSONB. + +Inet and CIDR Mapping + +pgx encodes from net.IPNet to and from inet and cidr PostgreSQL types. In addition, as a convenience pgx will encode +from a net.IP; it will assume a /32 netmask for IPv4 and a /128 for IPv6. + +Custom Type Support + +pgx includes support for the common data types like integers, floats, strings, dates, and times that have direct +mappings between Go and SQL. In addition, pgx uses the github.com/jackc/pgtype library to support more types. See +documention for that library for instructions on how to implement custom types. + +See example_custom_type_test.go for an example of a custom type for the PostgreSQL point type. + +pgx also includes support for custom types implementing the database/sql.Scanner and database/sql/driver.Valuer +interfaces. + +If pgx does cannot natively encode a type and that type is a renamed type (e.g. type MyTime time.Time) pgx will attempt +to encode the underlying type. While this is usually desired behavior it can produce surprising behavior if one the +underlying type and the renamed type each implement database/sql interfaces and the other implements pgx interfaces. It +is recommended that this situation be avoided by implementing pgx interfaces on the renamed type. + +Composite types and row values + +Row values and composite types are represented as pgtype.Record (https://pkg.go.dev/github.com/jackc/pgtype?tab=doc#Record). +It is possible to get values of your custom type by implementing DecodeBinary interface. Decoding into +pgtype.Record first can simplify process by avoiding dealing with raw protocol directly. + +For example: + + type MyType struct { + a int // NULL will cause decoding error + b *string // there can be NULL in this position in SQL + } + + func (t *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + r := pgtype.Record{ + Fields: []pgtype.Value{&pgtype.Int4{}, &pgtype.Text{}}, + } + + if err := r.DecodeBinary(ci, src); err != nil { + return err + } + + if r.Status != pgtype.Present { + return errors.New("BUG: decoding should not be called on NULL value") + } + + a := r.Fields[0].(*pgtype.Int4) + b := r.Fields[1].(*pgtype.Text) + + // type compatibility is checked by AssignTo + // only lossless assignments will succeed + if err := a.AssignTo(&t.a); err != nil { + return err + } + + // AssignTo also deals with null value handling + if err := b.AssignTo(&t.b); err != nil { + return err + } + return nil + } + + result := MyType{} + err := conn.QueryRow(context.Background(), "select row(1, 'foo'::text)", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&r) + +Raw Bytes Mapping + +[]byte passed as arguments to Query, QueryRow, and Exec are passed unmodified to PostgreSQL. + +Transactions + +Transactions are started by calling Begin. + + tx, err := conn.Begin(context.Background()) + if err != nil { + return err + } + // Rollback is safe to call even if the tx is already closed, so if + // the tx commits successfully, this is a no-op + defer tx.Rollback(context.Background()) + + _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)") + if err != nil { + return err + } + + err = tx.Commit(context.Background()) + if err != nil { + return err + } + +The Tx returned from Begin also implements the Begin method. This can be used to implement pseudo nested transactions. +These are internally implemented with savepoints. + +Use BeginTx to control the transaction mode. + +BeginFunc and BeginTxFunc are variants that begin a transaction, execute a function, and commit or rollback the +transaction depending on the return value of the function. These can be simpler and less error prone to use. + + err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error { + _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") + return err + }) + if err != nil { + return err + } + +Prepared Statements + +Prepared statements can be manually created with the Prepare method. However, this is rarely necessary because pgx +includes an automatic statement cache by default. Queries run through the normal Query, QueryRow, and Exec functions are +automatically prepared on first execution and the prepared statement is reused on subsequent executions. See ParseConfig +for information on how to customize or disable the statement cache. + +Copy Protocol + +Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL copy protocol. CopyFrom accepts a +CopyFromSource interface. If the data is already in a [][]interface{} use CopyFromRows to wrap it in a CopyFromSource +interface. Or implement CopyFromSource to avoid buffering the entire data set in memory. + + rows := [][]interface{}{ + {"John", "Smith", int32(36)}, + {"Jane", "Doe", int32(29)}, + } + + copyCount, err := conn.CopyFrom( + context.Background(), + pgx.Identifier{"people"}, + []string{"first_name", "last_name", "age"}, + pgx.CopyFromRows(rows), + ) + +When you already have a typed array using CopyFromSlice can be more convenient. + + rows := []User{ + {"John", "Smith", 36}, + {"Jane", "Doe", 29}, + } + + copyCount, err := conn.CopyFrom( + context.Background(), + pgx.Identifier{"people"}, + []string{"first_name", "last_name", "age"}, + pgx.CopyFromSlice(len(rows), func(i int) ([]interface{}, error) { + return []interface{}{rows[i].FirstName, rows[i].LastName, rows[i].Age}, nil + }), + ) + +CopyFrom can be faster than an insert with as few as 5 rows. + +Listen and Notify + +pgx can listen to the PostgreSQL notification system with the `Conn.WaitForNotification` method. It blocks until a +context is received or the context is canceled. + + _, err := conn.Exec(context.Background(), "listen channelname") + if err != nil { + return nil + } + + if notification, err := conn.WaitForNotification(context.Background()); err != nil { + // do something with notification + } + + +Logging + +pgx defines a simple logger interface. Connections optionally accept a logger that satisfies this interface. Set +LogLevel to control logging verbosity. Adapters for github.com/inconshreveable/log15, github.com/sirupsen/logrus, +go.uber.org/zap, github.com/rs/zerolog, and the testing log are provided in the log directory. + +Lower Level PostgreSQL Functionality + +pgx is implemented on top of github.com/jackc/pgconn a lower level PostgreSQL driver. The Conn.PgConn() method can be +used to access this lower layer. + +PgBouncer + +pgx is compatible with PgBouncer in two modes. One is when the connection has a statement cache in "describe" mode. The +other is when the connection is using the simple protocol. This can be set with the PreferSimpleProtocol config option. +*/ +package pgx diff --git a/vendor/github.com/jackc/pgx/v4/extended_query_builder.go b/vendor/github.com/jackc/pgx/v4/extended_query_builder.go new file mode 100644 index 000000000..09419f0d0 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/extended_query_builder.go @@ -0,0 +1,168 @@ +package pgx + +import ( + "database/sql/driver" + "fmt" + "reflect" + + "github.com/jackc/pgtype" +) + +type extendedQueryBuilder struct { + paramValues [][]byte + paramValueBytes []byte + paramFormats []int16 + resultFormats []int16 + + resetCount int +} + +func (eqb *extendedQueryBuilder) AppendParam(ci *pgtype.ConnInfo, oid uint32, arg interface{}) error { + f := chooseParameterFormatCode(ci, oid, arg) + eqb.paramFormats = append(eqb.paramFormats, f) + + v, err := eqb.encodeExtendedParamValue(ci, oid, f, arg) + if err != nil { + return err + } + eqb.paramValues = append(eqb.paramValues, v) + + return nil +} + +func (eqb *extendedQueryBuilder) AppendResultFormat(f int16) { + eqb.resultFormats = append(eqb.resultFormats, f) +} + +func (eqb *extendedQueryBuilder) Reset() { + eqb.paramValues = eqb.paramValues[0:0] + eqb.paramValueBytes = eqb.paramValueBytes[0:0] + eqb.paramFormats = eqb.paramFormats[0:0] + eqb.resultFormats = eqb.resultFormats[0:0] + + eqb.resetCount++ + + // Every so often shrink our reserved memory if it is abnormally high + if eqb.resetCount%128 == 0 { + if cap(eqb.paramValues) > 64 { + eqb.paramValues = make([][]byte, 0, cap(eqb.paramValues)/2) + } + + if cap(eqb.paramValueBytes) > 256 { + eqb.paramValueBytes = make([]byte, 0, cap(eqb.paramValueBytes)/2) + } + + if cap(eqb.paramFormats) > 64 { + eqb.paramFormats = make([]int16, 0, cap(eqb.paramFormats)/2) + } + if cap(eqb.resultFormats) > 64 { + eqb.resultFormats = make([]int16, 0, cap(eqb.resultFormats)/2) + } + } + +} + +func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, oid uint32, formatCode int16, arg interface{}) ([]byte, error) { + if arg == nil { + return nil, nil + } + + refVal := reflect.ValueOf(arg) + argIsPtr := refVal.Kind() == reflect.Ptr + + if argIsPtr && refVal.IsNil() { + return nil, nil + } + + if eqb.paramValueBytes == nil { + eqb.paramValueBytes = make([]byte, 0, 128) + } + + var err error + var buf []byte + pos := len(eqb.paramValueBytes) + + if arg, ok := arg.(string); ok { + return []byte(arg), nil + } + + if formatCode == TextFormatCode { + if arg, ok := arg.(pgtype.TextEncoder); ok { + buf, err = arg.EncodeText(ci, eqb.paramValueBytes) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + eqb.paramValueBytes = buf + return eqb.paramValueBytes[pos:], nil + } + } else if formatCode == BinaryFormatCode { + if arg, ok := arg.(pgtype.BinaryEncoder); ok { + buf, err = arg.EncodeBinary(ci, eqb.paramValueBytes) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + eqb.paramValueBytes = buf + return eqb.paramValueBytes[pos:], nil + } + } + + if argIsPtr { + // We have already checked that arg is not pointing to nil, + // so it is safe to dereference here. + arg = refVal.Elem().Interface() + return eqb.encodeExtendedParamValue(ci, oid, formatCode, arg) + } + + if dt, ok := ci.DataTypeForOID(oid); ok { + value := dt.Value + err := value.Set(arg) + if err != nil { + { + if arg, ok := arg.(driver.Valuer); ok { + v, err := callValuerValue(arg) + if err != nil { + return nil, err + } + return eqb.encodeExtendedParamValue(ci, oid, formatCode, v) + } + } + + return nil, err + } + + return eqb.encodeExtendedParamValue(ci, oid, formatCode, value) + } + + // There is no data type registered for the destination OID, but maybe there is data type registered for the arg + // type. If so use it's text encoder (if available). + if dt, ok := ci.DataTypeForValue(arg); ok { + value := dt.Value + if textEncoder, ok := value.(pgtype.TextEncoder); ok { + err := value.Set(arg) + if err != nil { + return nil, err + } + + buf, err = textEncoder.EncodeText(ci, eqb.paramValueBytes) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + eqb.paramValueBytes = buf + return eqb.paramValueBytes[pos:], nil + } + } + + if strippedArg, ok := stripNamedType(&refVal); ok { + return eqb.encodeExtendedParamValue(ci, oid, formatCode, strippedArg) + } + return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) +} diff --git a/vendor/github.com/jackc/pgx/v4/go.mod b/vendor/github.com/jackc/pgx/v4/go.mod new file mode 100644 index 000000000..2d877032b --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/go.mod @@ -0,0 +1,21 @@ +module github.com/jackc/pgx/v4 + +go 1.13 + +require ( + github.com/Masterminds/semver/v3 v3.1.1 + github.com/cockroachdb/apd v1.1.0 + github.com/go-kit/log v0.1.0 + github.com/gofrs/uuid v4.0.0+incompatible + github.com/jackc/pgconn v1.10.0 + github.com/jackc/pgio v1.0.0 + github.com/jackc/pgproto3/v2 v2.1.1 + github.com/jackc/pgtype v1.8.1 + github.com/jackc/puddle v1.1.3 + github.com/rs/zerolog v1.15.0 + github.com/shopspring/decimal v1.2.0 + github.com/sirupsen/logrus v1.4.2 + github.com/stretchr/testify v1.7.0 + go.uber.org/zap v1.13.0 + gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec +) diff --git a/vendor/github.com/jackc/pgx/v4/go.sum b/vendor/github.com/jackc/pgx/v4/go.sum new file mode 100644 index 000000000..2222449d8 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/go.sum @@ -0,0 +1,196 @@ +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= +github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= +github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= +github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-kit/log v0.1.0 h1:DGJh0Sm43HbOeYDNnVZFl8BvcYVvjD5bqYJvp0REbwQ= +github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= +github.com/go-logfmt/logfmt v0.5.0 h1:TrB8swr/68K7m9CcGut2g3UOihhbcbiMAYiuTXdEih4= +github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= +github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= +github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= +github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= +github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= +github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= +github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= +github.com/jackc/pgconn v1.10.0 h1:4EYhlDVEMsJ30nNj0mmgwIUXoq7e9sMJrVC2ED6QlCU= +github.com/jackc/pgconn v1.10.0/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= +github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= +github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI= +github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= +github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= +github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= +github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= +github.com/jackc/pgtype v1.8.1 h1:9k0IXtdJXHJbyAWQgbWr1lU+MEhPXZz6RIXxfR5oxXs= +github.com/jackc/pgtype v1.8.1/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= +github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= +github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= +github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= +github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= +github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.1.3 h1:JnPg/5Q9xVJGfjsO5CPUOjnJps1JaRUm8I9FXVCFK94= +github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= +github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+vvnE= +github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= +github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= +github.com/rs/zerolog v1.15.0 h1:uPRuwkWF4J6fGsJ2R0Gn2jB1EQiav9k3S6CSdygQJXY= +github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= +github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= +go.uber.org/multierr v1.5.0 h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A= +go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= +go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.13.0 h1:nR6NoDBgAf67s68NhaXbsojM+2gxp3S1hWkHDl27pVU= +go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200103221440-774c71fcf114 h1:DnSr2mCsxyCE6ZgIkmcWUQY2R5cH/6wL7eIxEmQOMSE= +golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec h1:RlWgLqCMMIYYEVcAR5MDsuHlVkaIPDAF+5Dehzg8L5A= +gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/vendor/github.com/jackc/pgx/v4/go_stdlib.go b/vendor/github.com/jackc/pgx/v4/go_stdlib.go new file mode 100644 index 000000000..9372f9efa --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/go_stdlib.go @@ -0,0 +1,61 @@ +package pgx + +import ( + "database/sql/driver" + "reflect" +) + +// This file contains code copied from the Go standard library due to the +// required function not being public. + +// Copyright (c) 2009 The Go Authors. All rights reserved. + +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: + +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. + +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// From database/sql/convert.go + +var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + +// callValuerValue returns vr.Value(), with one exception: +// If vr.Value is an auto-generated method on a pointer type and the +// pointer is nil, it would panic at runtime in the panicwrap +// method. Treat it like nil instead. +// Issue 8415. +// +// This is so people can implement driver.Value on value types and +// still use nil pointers to those types to mean nil/NULL, just like +// string/*string. +// +// This function is mirrored in the database/sql/driver package. +func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { + if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr && + rv.IsNil() && + rv.Type().Elem().Implements(valuerReflectType) { + return nil, nil + } + return vr.Value() +} diff --git a/vendor/github.com/jackc/pgx/v4/internal/sanitize/sanitize.go b/vendor/github.com/jackc/pgx/v4/internal/sanitize/sanitize.go new file mode 100644 index 000000000..2dba3b810 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/internal/sanitize/sanitize.go @@ -0,0 +1,304 @@ +package sanitize + +import ( + "bytes" + "encoding/hex" + "fmt" + "strconv" + "strings" + "time" + "unicode/utf8" +) + +// Part is either a string or an int. A string is raw SQL. An int is a +// argument placeholder. +type Part interface{} + +type Query struct { + Parts []Part +} + +func (q *Query) Sanitize(args ...interface{}) (string, error) { + argUse := make([]bool, len(args)) + buf := &bytes.Buffer{} + + for _, part := range q.Parts { + var str string + switch part := part.(type) { + case string: + str = part + case int: + argIdx := part - 1 + if argIdx >= len(args) { + return "", fmt.Errorf("insufficient arguments") + } + arg := args[argIdx] + switch arg := arg.(type) { + case nil: + str = "null" + case int64: + str = strconv.FormatInt(arg, 10) + case float64: + str = strconv.FormatFloat(arg, 'f', -1, 64) + case bool: + str = strconv.FormatBool(arg) + case []byte: + str = QuoteBytes(arg) + case string: + str = QuoteString(arg) + case time.Time: + str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'") + default: + return "", fmt.Errorf("invalid arg type: %T", arg) + } + argUse[argIdx] = true + default: + return "", fmt.Errorf("invalid Part type: %T", part) + } + buf.WriteString(str) + } + + for i, used := range argUse { + if !used { + return "", fmt.Errorf("unused argument: %d", i) + } + } + return buf.String(), nil +} + +func NewQuery(sql string) (*Query, error) { + l := &sqlLexer{ + src: sql, + stateFn: rawState, + } + + for l.stateFn != nil { + l.stateFn = l.stateFn(l) + } + + query := &Query{Parts: l.parts} + + return query, nil +} + +func QuoteString(str string) string { + return "'" + strings.ReplaceAll(str, "'", "''") + "'" +} + +func QuoteBytes(buf []byte) string { + return `'\x` + hex.EncodeToString(buf) + "'" +} + +type sqlLexer struct { + src string + start int + pos int + nested int // multiline comment nesting level. + stateFn stateFn + parts []Part +} + +type stateFn func(*sqlLexer) stateFn + +func rawState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case 'e', 'E': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '\'' { + l.pos += width + return escapeStringState + } + case '\'': + return singleQuoteState + case '"': + return doubleQuoteState + case '$': + nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) + if '0' <= nextRune && nextRune <= '9' { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos-width]) + } + l.start = l.pos + return placeholderState + } + case '-': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '-' { + l.pos += width + return oneLineCommentState + } + case '/': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '*' { + l.pos += width + return multilineCommentState + } + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func singleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func doubleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '"': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '"' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +// placeholderState consumes a placeholder value. The $ must have already has +// already been consumed. The first rune must be a digit. +func placeholderState(l *sqlLexer) stateFn { + num := 0 + + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + if '0' <= r && r <= '9' { + num *= 10 + num += int(r - '0') + } else { + l.parts = append(l.parts, num) + l.pos -= width + l.start = l.pos + return rawState + } + } +} + +func escapeStringState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\\': + _, width = utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func oneLineCommentState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\\': + _, width = utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + case '\n': + return rawState + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func multilineCommentState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '/': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '*' { + l.pos += width + l.nested++ + } + case '*': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '/' { + continue + } + + l.pos += width + if l.nested == 0 { + return rawState + } + l.nested-- + + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +// SanitizeSQL replaces placeholder values with args. It quotes and escapes args +// as necessary. This function is only safe when standard_conforming_strings is +// on. +func SanitizeSQL(sql string, args ...interface{}) (string, error) { + query, err := NewQuery(sql) + if err != nil { + return "", err + } + return query.Sanitize(args...) +} diff --git a/vendor/github.com/jackc/pgx/v4/large_objects.go b/vendor/github.com/jackc/pgx/v4/large_objects.go new file mode 100644 index 000000000..5255a3b48 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/large_objects.go @@ -0,0 +1,121 @@ +package pgx + +import ( + "context" + "errors" + "io" +) + +// LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it +// was created. +// +// For more details see: http://www.postgresql.org/docs/current/static/largeobjects.html +type LargeObjects struct { + tx Tx +} + +type LargeObjectMode int32 + +const ( + LargeObjectModeWrite LargeObjectMode = 0x20000 + LargeObjectModeRead LargeObjectMode = 0x40000 +) + +// Create creates a new large object. If oid is zero, the server assigns an unused OID. +func (o *LargeObjects) Create(ctx context.Context, oid uint32) (uint32, error) { + err := o.tx.QueryRow(ctx, "select lo_create($1)", oid).Scan(&oid) + return oid, err +} + +// Open opens an existing large object with the given mode. ctx will also be used for all operations on the opened large +// object. +func (o *LargeObjects) Open(ctx context.Context, oid uint32, mode LargeObjectMode) (*LargeObject, error) { + var fd int32 + err := o.tx.QueryRow(ctx, "select lo_open($1, $2)", oid, mode).Scan(&fd) + if err != nil { + return nil, err + } + return &LargeObject{fd: fd, tx: o.tx, ctx: ctx}, nil +} + +// Unlink removes a large object from the database. +func (o *LargeObjects) Unlink(ctx context.Context, oid uint32) error { + var result int32 + err := o.tx.QueryRow(ctx, "select lo_unlink($1)", oid).Scan(&result) + if err != nil { + return err + } + + if result != 1 { + return errors.New("failed to remove large object") + } + + return nil +} + +// A LargeObject is a large object stored on the server. It is only valid within the transaction that it was initialized +// in. It uses the context it was initialized with for all operations. It implements these interfaces: +// +// io.Writer +// io.Reader +// io.Seeker +// io.Closer +type LargeObject struct { + ctx context.Context + tx Tx + fd int32 +} + +// Write writes p to the large object and returns the number of bytes written and an error if not all of p was written. +func (o *LargeObject) Write(p []byte) (int, error) { + var n int + err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p).Scan(&n) + if err != nil { + return n, err + } + + if n < 0 { + return 0, errors.New("failed to write to large object") + } + + return n, nil +} + +// Read reads up to len(p) bytes into p returning the number of bytes read. +func (o *LargeObject) Read(p []byte) (int, error) { + var res []byte + err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, len(p)).Scan(&res) + copy(p, res) + if err != nil { + return len(res), err + } + + if len(res) < len(p) { + err = io.EOF + } + return len(res), err +} + +// Seek moves the current location pointer to the new location specified by offset. +func (o *LargeObject) Seek(offset int64, whence int) (n int64, err error) { + err = o.tx.QueryRow(o.ctx, "select lo_lseek64($1, $2, $3)", o.fd, offset, whence).Scan(&n) + return n, err +} + +// Tell returns the current read or write location of the large object descriptor. +func (o *LargeObject) Tell() (n int64, err error) { + err = o.tx.QueryRow(o.ctx, "select lo_tell64($1)", o.fd).Scan(&n) + return n, err +} + +// Trunctes the large object to size. +func (o *LargeObject) Truncate(size int64) (err error) { + _, err = o.tx.Exec(o.ctx, "select lo_truncate64($1, $2)", o.fd, size) + return err +} + +// Close closees the large object descriptor. +func (o *LargeObject) Close() error { + _, err := o.tx.Exec(o.ctx, "select lo_close($1)", o.fd) + return err +} diff --git a/vendor/github.com/jackc/pgx/v4/logger.go b/vendor/github.com/jackc/pgx/v4/logger.go new file mode 100644 index 000000000..89fd5af51 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/logger.go @@ -0,0 +1,98 @@ +package pgx + +import ( + "context" + "encoding/hex" + "errors" + "fmt" +) + +// The values for log levels are chosen such that the zero value means that no +// log level was specified. +const ( + LogLevelTrace = 6 + LogLevelDebug = 5 + LogLevelInfo = 4 + LogLevelWarn = 3 + LogLevelError = 2 + LogLevelNone = 1 +) + +// LogLevel represents the pgx logging level. See LogLevel* constants for +// possible values. +type LogLevel int + +func (ll LogLevel) String() string { + switch ll { + case LogLevelTrace: + return "trace" + case LogLevelDebug: + return "debug" + case LogLevelInfo: + return "info" + case LogLevelWarn: + return "warn" + case LogLevelError: + return "error" + case LogLevelNone: + return "none" + default: + return fmt.Sprintf("invalid level %d", ll) + } +} + +// Logger is the interface used to get logging from pgx internals. +type Logger interface { + // Log a message at the given level with data key/value pairs. data may be nil. + Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) +} + +// LogLevelFromString converts log level string to constant +// +// Valid levels: +// trace +// debug +// info +// warn +// error +// none +func LogLevelFromString(s string) (LogLevel, error) { + switch s { + case "trace": + return LogLevelTrace, nil + case "debug": + return LogLevelDebug, nil + case "info": + return LogLevelInfo, nil + case "warn": + return LogLevelWarn, nil + case "error": + return LogLevelError, nil + case "none": + return LogLevelNone, nil + default: + return 0, errors.New("invalid log level") + } +} + +func logQueryArgs(args []interface{}) []interface{} { + logArgs := make([]interface{}, 0, len(args)) + + for _, a := range args { + switch v := a.(type) { + case []byte: + if len(v) < 64 { + a = hex.EncodeToString(v) + } else { + a = fmt.Sprintf("%x (truncated %d bytes)", v[:64], len(v)-64) + } + case string: + if len(v) > 64 { + a = fmt.Sprintf("%s (truncated %d bytes)", v[:64], len(v)-64) + } + } + logArgs = append(logArgs, a) + } + + return logArgs +} diff --git a/vendor/github.com/jackc/pgx/v4/messages.go b/vendor/github.com/jackc/pgx/v4/messages.go new file mode 100644 index 000000000..5324cbb5c --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/messages.go @@ -0,0 +1,23 @@ +package pgx + +import ( + "database/sql/driver" + + "github.com/jackc/pgtype" +) + +func convertDriverValuers(args []interface{}) ([]interface{}, error) { + for i, arg := range args { + switch arg := arg.(type) { + case pgtype.BinaryEncoder: + case pgtype.TextEncoder: + case driver.Valuer: + v, err := callValuerValue(arg) + if err != nil { + return nil, err + } + args[i] = v + } + } + return args, nil +} diff --git a/vendor/github.com/jackc/pgx/v4/rows.go b/vendor/github.com/jackc/pgx/v4/rows.go new file mode 100644 index 000000000..d57d5cbf6 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/rows.go @@ -0,0 +1,347 @@ +package pgx + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/jackc/pgconn" + "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgtype" +) + +// Rows is the result set returned from *Conn.Query. Rows must be closed before +// the *Conn can be used again. Rows are closed by explicitly calling Close(), +// calling Next() until it returns false, or when a fatal error occurs. +// +// Once a Rows is closed the only methods that may be called are Close(), Err(), and CommandTag(). +// +// Rows is an interface instead of a struct to allow tests to mock Query. However, +// adding a method to an interface is technically a breaking change. Because of this +// the Rows interface is partially excluded from semantic version requirements. +// Methods will not be removed or changed, but new methods may be added. +type Rows interface { + // Close closes the rows, making the connection ready for use again. It is safe + // to call Close after rows is already closed. + Close() + + // Err returns any error that occurred while reading. + Err() error + + // CommandTag returns the command tag from this query. It is only available after Rows is closed. + CommandTag() pgconn.CommandTag + + FieldDescriptions() []pgproto3.FieldDescription + + // Next prepares the next row for reading. It returns true if there is another + // row and false if no more rows are available. It automatically closes rows + // when all rows are read. + Next() bool + + // Scan reads the values from the current row into dest values positionally. + // dest can include pointers to core types, values implementing the Scanner + // interface, and nil. nil will skip the value entirely. + Scan(dest ...interface{}) error + + // Values returns the decoded row values. + Values() ([]interface{}, error) + + // RawValues returns the unparsed bytes of the row values. The returned [][]byte is only valid until the next Next + // call or the Rows is closed. However, the underlying byte data is safe to retain a reference to and mutate. + RawValues() [][]byte +} + +// Row is a convenience wrapper over Rows that is returned by QueryRow. +// +// Row is an interface instead of a struct to allow tests to mock QueryRow. However, +// adding a method to an interface is technically a breaking change. Because of this +// the Row interface is partially excluded from semantic version requirements. +// Methods will not be removed or changed, but new methods may be added. +type Row interface { + // Scan works the same as Rows. with the following exceptions. If no + // rows were found it returns ErrNoRows. If multiple rows are returned it + // ignores all but the first. + Scan(dest ...interface{}) error +} + +// connRow implements the Row interface for Conn.QueryRow. +type connRow connRows + +func (r *connRow) Scan(dest ...interface{}) (err error) { + rows := (*connRows)(r) + + if rows.Err() != nil { + return rows.Err() + } + + if !rows.Next() { + if rows.Err() == nil { + return ErrNoRows + } + return rows.Err() + } + + rows.Scan(dest...) + rows.Close() + return rows.Err() +} + +type rowLog interface { + shouldLog(lvl LogLevel) bool + log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{}) +} + +// connRows implements the Rows interface for Conn.Query. +type connRows struct { + ctx context.Context + logger rowLog + connInfo *pgtype.ConnInfo + values [][]byte + rowCount int + err error + commandTag pgconn.CommandTag + startTime time.Time + sql string + args []interface{} + closed bool + conn *Conn + + resultReader *pgconn.ResultReader + multiResultReader *pgconn.MultiResultReader + + scanPlans []pgtype.ScanPlan +} + +func (rows *connRows) FieldDescriptions() []pgproto3.FieldDescription { + return rows.resultReader.FieldDescriptions() +} + +func (rows *connRows) Close() { + if rows.closed { + return + } + + rows.closed = true + + if rows.resultReader != nil { + var closeErr error + rows.commandTag, closeErr = rows.resultReader.Close() + if rows.err == nil { + rows.err = closeErr + } + } + + if rows.multiResultReader != nil { + closeErr := rows.multiResultReader.Close() + if rows.err == nil { + rows.err = closeErr + } + } + + if rows.logger != nil { + if rows.err == nil { + if rows.logger.shouldLog(LogLevelInfo) { + endTime := time.Now() + rows.logger.log(rows.ctx, LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount}) + } + } else { + if rows.logger.shouldLog(LogLevelError) { + rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]interface{}{"err": rows.err, "sql": rows.sql, "args": logQueryArgs(rows.args)}) + } + if rows.err != nil && rows.conn.stmtcache != nil { + rows.conn.stmtcache.StatementErrored(rows.sql, rows.err) + } + } + } +} + +func (rows *connRows) CommandTag() pgconn.CommandTag { + return rows.commandTag +} + +func (rows *connRows) Err() error { + return rows.err +} + +// fatal signals an error occurred after the query was sent to the server. It +// closes the rows automatically. +func (rows *connRows) fatal(err error) { + if rows.err != nil { + return + } + + rows.err = err + rows.Close() +} + +func (rows *connRows) Next() bool { + if rows.closed { + return false + } + + if rows.resultReader.NextRow() { + rows.rowCount++ + rows.values = rows.resultReader.Values() + return true + } else { + rows.Close() + return false + } +} + +func (rows *connRows) Scan(dest ...interface{}) error { + ci := rows.connInfo + fieldDescriptions := rows.FieldDescriptions() + values := rows.values + + if len(fieldDescriptions) != len(values) { + err := fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values)) + rows.fatal(err) + return err + } + if len(fieldDescriptions) != len(dest) { + err := fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest)) + rows.fatal(err) + return err + } + + if rows.scanPlans == nil { + rows.scanPlans = make([]pgtype.ScanPlan, len(values)) + for i := range dest { + rows.scanPlans[i] = ci.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i]) + } + } + + for i, dst := range dest { + if dst == nil { + continue + } + + err := rows.scanPlans[i].Scan(ci, fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], dst) + if err != nil { + err = ScanArgError{ColumnIndex: i, Err: err} + rows.fatal(err) + return err + } + } + + return nil +} + +func (rows *connRows) Values() ([]interface{}, error) { + if rows.closed { + return nil, errors.New("rows is closed") + } + + values := make([]interface{}, 0, len(rows.FieldDescriptions())) + + for i := range rows.FieldDescriptions() { + buf := rows.values[i] + fd := &rows.FieldDescriptions()[i] + + if buf == nil { + values = append(values, nil) + continue + } + + if dt, ok := rows.connInfo.DataTypeForOID(fd.DataTypeOID); ok { + value := dt.Value + + switch fd.Format { + case TextFormatCode: + decoder, ok := value.(pgtype.TextDecoder) + if !ok { + decoder = &pgtype.GenericText{} + } + err := decoder.DecodeText(rows.connInfo, buf) + if err != nil { + rows.fatal(err) + } + values = append(values, decoder.(pgtype.Value).Get()) + case BinaryFormatCode: + decoder, ok := value.(pgtype.BinaryDecoder) + if !ok { + decoder = &pgtype.GenericBinary{} + } + err := decoder.DecodeBinary(rows.connInfo, buf) + if err != nil { + rows.fatal(err) + } + values = append(values, value.Get()) + default: + rows.fatal(errors.New("Unknown format code")) + } + } else { + switch fd.Format { + case TextFormatCode: + decoder := &pgtype.GenericText{} + err := decoder.DecodeText(rows.connInfo, buf) + if err != nil { + rows.fatal(err) + } + values = append(values, decoder.Get()) + case BinaryFormatCode: + decoder := &pgtype.GenericBinary{} + err := decoder.DecodeBinary(rows.connInfo, buf) + if err != nil { + rows.fatal(err) + } + values = append(values, decoder.Get()) + default: + rows.fatal(errors.New("Unknown format code")) + } + } + + if rows.Err() != nil { + return nil, rows.Err() + } + } + + return values, rows.Err() +} + +func (rows *connRows) RawValues() [][]byte { + return rows.values +} + +type ScanArgError struct { + ColumnIndex int + Err error +} + +func (e ScanArgError) Error() string { + return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err) +} + +func (e ScanArgError) Unwrap() error { + return e.Err +} + +// ScanRow decodes raw row data into dest. It can be used to scan rows read from the lower level pgconn interface. +// +// connInfo - OID to Go type mapping. +// fieldDescriptions - OID and format of values +// values - the raw data as returned from the PostgreSQL server +// dest - the destination that values will be decoded into +func ScanRow(connInfo *pgtype.ConnInfo, fieldDescriptions []pgproto3.FieldDescription, values [][]byte, dest ...interface{}) error { + if len(fieldDescriptions) != len(values) { + return fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values)) + } + if len(fieldDescriptions) != len(dest) { + return fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest)) + } + + for i, d := range dest { + if d == nil { + continue + } + + err := connInfo.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d) + if err != nil { + return ScanArgError{ColumnIndex: i, Err: err} + } + } + + return nil +} diff --git a/vendor/github.com/jackc/pgx/v4/stdlib/sql.go b/vendor/github.com/jackc/pgx/v4/stdlib/sql.go new file mode 100644 index 000000000..fa81e73d5 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/stdlib/sql.go @@ -0,0 +1,858 @@ +// Package stdlib is the compatibility layer from pgx to database/sql. +// +// A database/sql connection can be established through sql.Open. +// +// db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable") +// if err != nil { +// return err +// } +// +// Or from a DSN string. +// +// db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable") +// if err != nil { +// return err +// } +// +// Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. In this case the +// pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used +// with sql.Open. +// +// connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL")) +// connConfig.Logger = myLogger +// connStr := stdlib.RegisterConnConfig(connConfig) +// db, _ := sql.Open("pgx", connStr) +// +// pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2. +// It does not support named parameters. +// +// db.QueryRow("select * from users where id=$1", userID) +// +// In Go 1.13 and above (*sql.Conn) Raw() can be used to get a *pgx.Conn from the standard +// database/sql.DB connection pool. This allows operations that use pgx specific functionality. +// +// // Given db is a *sql.DB +// conn, err := db.Conn(context.Background()) +// if err != nil { +// // handle error from acquiring connection from DB pool +// } +// +// err = conn.Raw(func(driverConn interface{}) error { +// conn := driverConn.(*stdlib.Conn).Conn() // conn is a *pgx.Conn +// // Do pgx specific stuff with conn +// conn.CopyFrom(...) +// return nil +// }) +// if err != nil { +// // handle error that occurred while using *pgx.Conn +// } +package stdlib + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "io" + "math" + "math/rand" + "reflect" + "strconv" + "strings" + "sync" + "time" + + "github.com/jackc/pgconn" + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4" +) + +// Only intrinsic types should be binary format with database/sql. +var databaseSQLResultFormats pgx.QueryResultFormatsByOID + +var pgxDriver *Driver + +type ctxKey int + +var ctxKeyFakeTx ctxKey = 0 + +var ErrNotPgx = errors.New("not pgx *sql.DB") + +func init() { + pgxDriver = &Driver{ + configs: make(map[string]*pgx.ConnConfig), + } + fakeTxConns = make(map[*pgx.Conn]*sql.Tx) + sql.Register("pgx", pgxDriver) + + databaseSQLResultFormats = pgx.QueryResultFormatsByOID{ + pgtype.BoolOID: 1, + pgtype.ByteaOID: 1, + pgtype.CIDOID: 1, + pgtype.DateOID: 1, + pgtype.Float4OID: 1, + pgtype.Float8OID: 1, + pgtype.Int2OID: 1, + pgtype.Int4OID: 1, + pgtype.Int8OID: 1, + pgtype.OIDOID: 1, + pgtype.TimestampOID: 1, + pgtype.TimestamptzOID: 1, + pgtype.XIDOID: 1, + } +} + +var ( + fakeTxMutex sync.Mutex + fakeTxConns map[*pgx.Conn]*sql.Tx +) + +// OptionOpenDB options for configuring the driver when opening a new db pool. +type OptionOpenDB func(*connector) + +// OptionBeforeConnect provides a callback for before connect. It is passed a shallow copy of the ConnConfig that will +// be used to connect, so only its immediate members should be modified. +func OptionBeforeConnect(bc func(context.Context, *pgx.ConnConfig) error) OptionOpenDB { + return func(dc *connector) { + dc.BeforeConnect = bc + } +} + +// OptionAfterConnect provides a callback for after connect. +func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB { + return func(dc *connector) { + dc.AfterConnect = ac + } +} + +// OptionResetSession provides a callback that can be used to add custom logic prior to executing a query on the +// connection if the connection has been used before. +// If ResetSessionFunc returns ErrBadConn error the connection will be discarded. +func OptionResetSession(rs func(context.Context, *pgx.Conn) error) OptionOpenDB { + return func(dc *connector) { + dc.ResetSession = rs + } +} + +// RandomizeHostOrderFunc is a BeforeConnect hook that randomizes the host order in the provided connConfig, so that a +// new host becomes primary each time. This is useful to distribute connections for multi-master databases like +// CockroachDB. If you use this you likely should set https://golang.org/pkg/database/sql/#DB.SetConnMaxLifetime as well +// to ensure that connections are periodically rebalanced across your nodes. +func RandomizeHostOrderFunc(ctx context.Context, connConfig *pgx.ConnConfig) error { + if len(connConfig.Fallbacks) == 0 { + return nil + } + + newFallbacks := append([]*pgconn.FallbackConfig{&pgconn.FallbackConfig{ + Host: connConfig.Host, + Port: connConfig.Port, + TLSConfig: connConfig.TLSConfig, + }}, connConfig.Fallbacks...) + + rand.Shuffle(len(newFallbacks), func(i, j int) { + newFallbacks[i], newFallbacks[j] = newFallbacks[j], newFallbacks[i] + }) + + // Use the one that sorted last as the primary and keep the rest as the fallbacks + newPrimary := newFallbacks[len(newFallbacks)-1] + connConfig.Host = newPrimary.Host + connConfig.Port = newPrimary.Port + connConfig.TLSConfig = newPrimary.TLSConfig + connConfig.Fallbacks = newFallbacks[:len(newFallbacks)-1] + return nil +} + +func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB { + c := connector{ + ConnConfig: config, + BeforeConnect: func(context.Context, *pgx.ConnConfig) error { return nil }, // noop before connect by default + AfterConnect: func(context.Context, *pgx.Conn) error { return nil }, // noop after connect by default + ResetSession: func(context.Context, *pgx.Conn) error { return nil }, // noop reset session by default + driver: pgxDriver, + } + + for _, opt := range opts { + opt(&c) + } + + return sql.OpenDB(c) +} + +type connector struct { + pgx.ConnConfig + BeforeConnect func(context.Context, *pgx.ConnConfig) error // function to call before creation of every new connection + AfterConnect func(context.Context, *pgx.Conn) error // function to call after creation of every new connection + ResetSession func(context.Context, *pgx.Conn) error // function is called before a connection is reused + driver *Driver +} + +// Connect implement driver.Connector interface +func (c connector) Connect(ctx context.Context) (driver.Conn, error) { + var ( + err error + conn *pgx.Conn + ) + + // Create a shallow copy of the config, so that BeforeConnect can safely modify it + connConfig := c.ConnConfig + if err = c.BeforeConnect(ctx, &connConfig); err != nil { + return nil, err + } + + if conn, err = pgx.ConnectConfig(ctx, &connConfig); err != nil { + return nil, err + } + + if err = c.AfterConnect(ctx, conn); err != nil { + return nil, err + } + + return &Conn{conn: conn, driver: c.driver, connConfig: connConfig, resetSessionFunc: c.ResetSession}, nil +} + +// Driver implement driver.Connector interface +func (c connector) Driver() driver.Driver { + return c.driver +} + +// GetDefaultDriver returns the driver initialized in the init function +// and used when the pgx driver is registered. +func GetDefaultDriver() driver.Driver { + return pgxDriver +} + +type Driver struct { + configMutex sync.Mutex + configs map[string]*pgx.ConnConfig + sequence int +} + +func (d *Driver) Open(name string) (driver.Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Ensure eventual timeout + defer cancel() + + connector, err := d.OpenConnector(name) + if err != nil { + return nil, err + } + return connector.Connect(ctx) +} + +func (d *Driver) OpenConnector(name string) (driver.Connector, error) { + return &driverConnector{driver: d, name: name}, nil +} + +func (d *Driver) registerConnConfig(c *pgx.ConnConfig) string { + d.configMutex.Lock() + connStr := fmt.Sprintf("registeredConnConfig%d", d.sequence) + d.sequence++ + d.configs[connStr] = c + d.configMutex.Unlock() + return connStr +} + +func (d *Driver) unregisterConnConfig(connStr string) { + d.configMutex.Lock() + delete(d.configs, connStr) + d.configMutex.Unlock() +} + +type driverConnector struct { + driver *Driver + name string +} + +func (dc *driverConnector) Connect(ctx context.Context) (driver.Conn, error) { + var connConfig *pgx.ConnConfig + + dc.driver.configMutex.Lock() + connConfig = dc.driver.configs[dc.name] + dc.driver.configMutex.Unlock() + + if connConfig == nil { + var err error + connConfig, err = pgx.ParseConfig(dc.name) + if err != nil { + return nil, err + } + } + + conn, err := pgx.ConnectConfig(ctx, connConfig) + if err != nil { + return nil, err + } + + c := &Conn{ + conn: conn, + driver: dc.driver, + connConfig: *connConfig, + resetSessionFunc: func(context.Context, *pgx.Conn) error { return nil }, + } + + return c, nil +} + +func (dc *driverConnector) Driver() driver.Driver { + return dc.driver +} + +// RegisterConnConfig registers a ConnConfig and returns the connection string to use with Open. +func RegisterConnConfig(c *pgx.ConnConfig) string { + return pgxDriver.registerConnConfig(c) +} + +// UnregisterConnConfig removes the ConnConfig registration for connStr. +func UnregisterConnConfig(connStr string) { + pgxDriver.unregisterConnConfig(connStr) +} + +type Conn struct { + conn *pgx.Conn + psCount int64 // Counter used for creating unique prepared statement names + driver *Driver + connConfig pgx.ConnConfig + resetSessionFunc func(context.Context, *pgx.Conn) error // Function is called before a connection is reused +} + +// Conn returns the underlying *pgx.Conn +func (c *Conn) Conn() *pgx.Conn { + return c.conn +} + +func (c *Conn) Prepare(query string) (driver.Stmt, error) { + return c.PrepareContext(context.Background(), query) +} + +func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if c.conn.IsClosed() { + return nil, driver.ErrBadConn + } + + name := fmt.Sprintf("pgx_%d", c.psCount) + c.psCount++ + + sd, err := c.conn.Prepare(ctx, name, query) + if err != nil { + return nil, err + } + + return &Stmt{sd: sd, conn: c}, nil +} + +func (c *Conn) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + return c.conn.Close(ctx) +} + +func (c *Conn) Begin() (driver.Tx, error) { + return c.BeginTx(context.Background(), driver.TxOptions{}) +} + +func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if c.conn.IsClosed() { + return nil, driver.ErrBadConn + } + + if pconn, ok := ctx.Value(ctxKeyFakeTx).(**pgx.Conn); ok { + *pconn = c.conn + return fakeTx{}, nil + } + + var pgxOpts pgx.TxOptions + switch sql.IsolationLevel(opts.Isolation) { + case sql.LevelDefault: + case sql.LevelReadUncommitted: + pgxOpts.IsoLevel = pgx.ReadUncommitted + case sql.LevelReadCommitted: + pgxOpts.IsoLevel = pgx.ReadCommitted + case sql.LevelRepeatableRead, sql.LevelSnapshot: + pgxOpts.IsoLevel = pgx.RepeatableRead + case sql.LevelSerializable: + pgxOpts.IsoLevel = pgx.Serializable + default: + return nil, fmt.Errorf("unsupported isolation: %v", opts.Isolation) + } + + if opts.ReadOnly { + pgxOpts.AccessMode = pgx.ReadOnly + } + + tx, err := c.conn.BeginTx(ctx, pgxOpts) + if err != nil { + return nil, err + } + + return wrapTx{ctx: ctx, tx: tx}, nil +} + +func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) { + if c.conn.IsClosed() { + return nil, driver.ErrBadConn + } + + args := namedValueToInterface(argsV) + + commandTag, err := c.conn.Exec(ctx, query, args...) + // if we got a network error before we had a chance to send the query, retry + if err != nil { + if pgconn.SafeToRetry(err) { + return nil, driver.ErrBadConn + } + } + return driver.RowsAffected(commandTag.RowsAffected()), err +} + +func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) { + if c.conn.IsClosed() { + return nil, driver.ErrBadConn + } + + args := []interface{}{databaseSQLResultFormats} + args = append(args, namedValueToInterface(argsV)...) + + rows, err := c.conn.Query(ctx, query, args...) + if err != nil { + if pgconn.SafeToRetry(err) { + return nil, driver.ErrBadConn + } + return nil, err + } + + // Preload first row because otherwise we won't know what columns are available when database/sql asks. + more := rows.Next() + if err = rows.Err(); err != nil { + rows.Close() + return nil, err + } + return &Rows{conn: c, rows: rows, skipNext: true, skipNextMore: more}, nil +} + +func (c *Conn) Ping(ctx context.Context) error { + if c.conn.IsClosed() { + return driver.ErrBadConn + } + + err := c.conn.Ping(ctx) + if err != nil { + // A Ping failure implies some sort of fatal state. The connection is almost certainly already closed by the + // failure, but manually close it just to be sure. + c.Close() + return driver.ErrBadConn + } + + return nil +} + +func (c *Conn) CheckNamedValue(*driver.NamedValue) error { + // Underlying pgx supports sql.Scanner and driver.Valuer interfaces natively. So everything can be passed through directly. + return nil +} + +func (c *Conn) ResetSession(ctx context.Context) error { + if c.conn.IsClosed() { + return driver.ErrBadConn + } + + return c.resetSessionFunc(ctx, c.conn) +} + +type Stmt struct { + sd *pgconn.StatementDescription + conn *Conn +} + +func (s *Stmt) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + return s.conn.conn.Deallocate(ctx, s.sd.Name) +} + +func (s *Stmt) NumInput() int { + return len(s.sd.ParamOIDs) +} + +func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) { + return nil, errors.New("Stmt.Exec deprecated and not implemented") +} + +func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) { + return s.conn.ExecContext(ctx, s.sd.Name, argsV) +} + +func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { + return nil, errors.New("Stmt.Query deprecated and not implemented") +} + +func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) { + return s.conn.QueryContext(ctx, s.sd.Name, argsV) +} + +type rowValueFunc func(src []byte) (driver.Value, error) + +type Rows struct { + conn *Conn + rows pgx.Rows + valueFuncs []rowValueFunc + skipNext bool + skipNextMore bool + + columnNames []string +} + +func (r *Rows) Columns() []string { + if r.columnNames == nil { + fields := r.rows.FieldDescriptions() + r.columnNames = make([]string, len(fields)) + for i, fd := range fields { + r.columnNames[i] = string(fd.Name) + } + } + + return r.columnNames +} + +// ColumnTypeDatabaseTypeName returns the database system type name. If the name is unknown the OID is returned. +func (r *Rows) ColumnTypeDatabaseTypeName(index int) string { + if dt, ok := r.conn.conn.ConnInfo().DataTypeForOID(r.rows.FieldDescriptions()[index].DataTypeOID); ok { + return strings.ToUpper(dt.Name) + } + + return strconv.FormatInt(int64(r.rows.FieldDescriptions()[index].DataTypeOID), 10) +} + +const varHeaderSize = 4 + +// ColumnTypeLength returns the length of the column type if the column is a +// variable length type. If the column is not a variable length type ok +// should return false. +func (r *Rows) ColumnTypeLength(index int) (int64, bool) { + fd := r.rows.FieldDescriptions()[index] + + switch fd.DataTypeOID { + case pgtype.TextOID, pgtype.ByteaOID: + return math.MaxInt64, true + case pgtype.VarcharOID, pgtype.BPCharArrayOID: + return int64(fd.TypeModifier - varHeaderSize), true + default: + return 0, false + } +} + +// ColumnTypePrecisionScale should return the precision and scale for decimal +// types. If not applicable, ok should be false. +func (r *Rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { + fd := r.rows.FieldDescriptions()[index] + + switch fd.DataTypeOID { + case pgtype.NumericOID: + mod := fd.TypeModifier - varHeaderSize + precision = int64((mod >> 16) & 0xffff) + scale = int64(mod & 0xffff) + return precision, scale, true + default: + return 0, 0, false + } +} + +// ColumnTypeScanType returns the value type that can be used to scan types into. +func (r *Rows) ColumnTypeScanType(index int) reflect.Type { + fd := r.rows.FieldDescriptions()[index] + + switch fd.DataTypeOID { + case pgtype.Float8OID: + return reflect.TypeOf(float64(0)) + case pgtype.Float4OID: + return reflect.TypeOf(float32(0)) + case pgtype.Int8OID: + return reflect.TypeOf(int64(0)) + case pgtype.Int4OID: + return reflect.TypeOf(int32(0)) + case pgtype.Int2OID: + return reflect.TypeOf(int16(0)) + case pgtype.BoolOID: + return reflect.TypeOf(false) + case pgtype.NumericOID: + return reflect.TypeOf(float64(0)) + case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID: + return reflect.TypeOf(time.Time{}) + case pgtype.ByteaOID: + return reflect.TypeOf([]byte(nil)) + default: + return reflect.TypeOf("") + } +} + +func (r *Rows) Close() error { + r.rows.Close() + return r.rows.Err() +} + +func (r *Rows) Next(dest []driver.Value) error { + ci := r.conn.conn.ConnInfo() + fieldDescriptions := r.rows.FieldDescriptions() + + if r.valueFuncs == nil { + r.valueFuncs = make([]rowValueFunc, len(fieldDescriptions)) + + for i, fd := range fieldDescriptions { + dataTypeOID := fd.DataTypeOID + format := fd.Format + + switch fd.DataTypeOID { + case pgtype.BoolOID: + var d bool + scanPlan := ci.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + return d, err + } + case pgtype.ByteaOID: + var d []byte + scanPlan := ci.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + return d, err + } + case pgtype.CIDOID: + var d pgtype.CID + scanPlan := ci.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + if err != nil { + return nil, err + } + return d.Value() + } + case pgtype.DateOID: + var d pgtype.Date + scanPlan := ci.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + if err != nil { + return nil, err + } + return d.Value() + } + case pgtype.Float4OID: + var d float32 + scanPlan := ci.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + return float64(d), err + } + case pgtype.Float8OID: + var d float64 + scanPlan := ci.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + return d, err + } + case pgtype.Int2OID: + var d int16 + scanPlan := ci.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + return int64(d), err + } + case pgtype.Int4OID: + var d int32 + scanPlan := ci.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + return int64(d), err + } + case pgtype.Int8OID: + var d int64 + scanPlan := ci.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + return d, err + } + case pgtype.JSONOID: + var d pgtype.JSON + scanPlan := ci.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + if err != nil { + return nil, err + } + return d.Value() + } + case pgtype.JSONBOID: + var d pgtype.JSONB + scanPlan := ci.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + if err != nil { + return nil, err + } + return d.Value() + } + case pgtype.OIDOID: + var d pgtype.OIDValue + scanPlan := ci.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + if err != nil { + return nil, err + } + return d.Value() + } + case pgtype.TimestampOID: + var d pgtype.Timestamp + scanPlan := ci.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + if err != nil { + return nil, err + } + return d.Value() + } + case pgtype.TimestamptzOID: + var d pgtype.Timestamptz + scanPlan := ci.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + if err != nil { + return nil, err + } + return d.Value() + } + case pgtype.XIDOID: + var d pgtype.XID + scanPlan := ci.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + if err != nil { + return nil, err + } + return d.Value() + } + default: + var d string + scanPlan := ci.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(ci, dataTypeOID, format, src, &d) + return d, err + } + } + } + } + + var more bool + if r.skipNext { + more = r.skipNextMore + r.skipNext = false + } else { + more = r.rows.Next() + } + + if !more { + if r.rows.Err() == nil { + return io.EOF + } else { + return r.rows.Err() + } + } + + for i, rv := range r.rows.RawValues() { + if rv != nil { + var err error + dest[i], err = r.valueFuncs[i](rv) + if err != nil { + return fmt.Errorf("convert field %d failed: %v", i, err) + } + } else { + dest[i] = nil + } + } + + return nil +} + +func valueToInterface(argsV []driver.Value) []interface{} { + args := make([]interface{}, 0, len(argsV)) + for _, v := range argsV { + if v != nil { + args = append(args, v.(interface{})) + } else { + args = append(args, nil) + } + } + return args +} + +func namedValueToInterface(argsV []driver.NamedValue) []interface{} { + args := make([]interface{}, 0, len(argsV)) + for _, v := range argsV { + if v.Value != nil { + args = append(args, v.Value.(interface{})) + } else { + args = append(args, nil) + } + } + return args +} + +type wrapTx struct { + ctx context.Context + tx pgx.Tx +} + +func (wtx wrapTx) Commit() error { return wtx.tx.Commit(wtx.ctx) } + +func (wtx wrapTx) Rollback() error { return wtx.tx.Rollback(wtx.ctx) } + +type fakeTx struct{} + +func (fakeTx) Commit() error { return nil } + +func (fakeTx) Rollback() error { return nil } + +// AcquireConn acquires a *pgx.Conn from database/sql connection pool. It must be released with ReleaseConn. +// +// In Go 1.13 this functionality has been incorporated into the standard library in the db.Conn.Raw() method. +func AcquireConn(db *sql.DB) (*pgx.Conn, error) { + var conn *pgx.Conn + ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &conn) + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + if conn == nil { + tx.Rollback() + return nil, ErrNotPgx + } + + fakeTxMutex.Lock() + fakeTxConns[conn] = tx + fakeTxMutex.Unlock() + + return conn, nil +} + +// ReleaseConn releases a *pgx.Conn acquired with AcquireConn. +func ReleaseConn(db *sql.DB, conn *pgx.Conn) error { + var tx *sql.Tx + var ok bool + + if conn.PgConn().IsBusy() || conn.PgConn().TxStatus() != 'I' { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + conn.Close(ctx) + } + + fakeTxMutex.Lock() + tx, ok = fakeTxConns[conn] + if ok { + delete(fakeTxConns, conn) + fakeTxMutex.Unlock() + } else { + fakeTxMutex.Unlock() + return fmt.Errorf("can't release conn that is not acquired") + } + + return tx.Rollback() +} diff --git a/vendor/github.com/jackc/pgx/v4/tx.go b/vendor/github.com/jackc/pgx/v4/tx.go new file mode 100644 index 000000000..7a296f4fe --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/tx.go @@ -0,0 +1,444 @@ +package pgx + +import ( + "bytes" + "context" + "errors" + "fmt" + "strconv" + + "github.com/jackc/pgconn" +) + +type TxIsoLevel string + +// Transaction isolation levels +const ( + Serializable = TxIsoLevel("serializable") + RepeatableRead = TxIsoLevel("repeatable read") + ReadCommitted = TxIsoLevel("read committed") + ReadUncommitted = TxIsoLevel("read uncommitted") +) + +type TxAccessMode string + +// Transaction access modes +const ( + ReadWrite = TxAccessMode("read write") + ReadOnly = TxAccessMode("read only") +) + +type TxDeferrableMode string + +// Transaction deferrable modes +const ( + Deferrable = TxDeferrableMode("deferrable") + NotDeferrable = TxDeferrableMode("not deferrable") +) + +type TxOptions struct { + IsoLevel TxIsoLevel + AccessMode TxAccessMode + DeferrableMode TxDeferrableMode +} + +var emptyTxOptions TxOptions + +func (txOptions TxOptions) beginSQL() string { + if txOptions == emptyTxOptions { + return "begin" + } + buf := &bytes.Buffer{} + buf.WriteString("begin") + if txOptions.IsoLevel != "" { + fmt.Fprintf(buf, " isolation level %s", txOptions.IsoLevel) + } + if txOptions.AccessMode != "" { + fmt.Fprintf(buf, " %s", txOptions.AccessMode) + } + if txOptions.DeferrableMode != "" { + fmt.Fprintf(buf, " %s", txOptions.DeferrableMode) + } + + return buf.String() +} + +var ErrTxClosed = errors.New("tx is closed") + +// ErrTxCommitRollback occurs when an error has occurred in a transaction and +// Commit() is called. PostgreSQL accepts COMMIT on aborted transactions, but +// it is treated as ROLLBACK. +var ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback") + +// Begin starts a transaction. Unlike database/sql, the context only affects the begin command. i.e. there is no +// auto-rollback on context cancellation. +func (c *Conn) Begin(ctx context.Context) (Tx, error) { + return c.BeginTx(ctx, TxOptions{}) +} + +// BeginTx starts a transaction with txOptions determining the transaction mode. Unlike database/sql, the context only +// affects the begin command. i.e. there is no auto-rollback on context cancellation. +func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) { + _, err := c.Exec(ctx, txOptions.beginSQL()) + if err != nil { + // begin should never fail unless there is an underlying connection issue or + // a context timeout. In either case, the connection is possibly broken. + c.die(errors.New("failed to begin transaction")) + return nil, err + } + + return &dbTx{conn: c}, nil +} + +// BeginFunc starts a transaction and calls f. If f does not return an error the transaction is committed. If f returns +// an error the transaction is rolled back. The context will be used when executing the transaction control statements +// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of f. +func (c *Conn) BeginFunc(ctx context.Context, f func(Tx) error) (err error) { + return c.BeginTxFunc(ctx, TxOptions{}, f) +} + +// BeginTxFunc starts a transaction with txOptions determining the transaction mode and calls f. If f does not return +// an error the transaction is committed. If f returns an error the transaction is rolled back. The context will be +// used when executing the transaction control statements (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect +// the execution of f. +func (c *Conn) BeginTxFunc(ctx context.Context, txOptions TxOptions, f func(Tx) error) (err error) { + var tx Tx + tx, err = c.BeginTx(ctx, txOptions) + if err != nil { + return err + } + defer func() { + rollbackErr := tx.Rollback(ctx) + if !(rollbackErr == nil || errors.Is(rollbackErr, ErrTxClosed)) { + err = rollbackErr + } + }() + + fErr := f(tx) + if fErr != nil { + _ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return + return fErr + } + + return tx.Commit(ctx) +} + +// Tx represents a database transaction. +// +// Tx is an interface instead of a struct to enable connection pools to be implemented without relying on internal pgx +// state, to support pseudo-nested transactions with savepoints, and to allow tests to mock transactions. However, +// adding a method to an interface is technically a breaking change. If new methods are added to Conn it may be +// desirable to add them to Tx as well. Because of this the Tx interface is partially excluded from semantic version +// requirements. Methods will not be removed or changed, but new methods may be added. +type Tx interface { + // Begin starts a pseudo nested transaction. + Begin(ctx context.Context) (Tx, error) + + // BeginFunc starts a pseudo nested transaction and executes f. If f does not return an err the pseudo nested + // transaction will be committed. If it does then it will be rolled back. + BeginFunc(ctx context.Context, f func(Tx) error) (err error) + + // Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested + // transaction. Commit will return ErrTxClosed if the Tx is already closed, but is otherwise safe to call multiple + // times. If the commit fails with a rollback status (e.g. the transaction was already in a broken state) then + // ErrTxCommitRollback will be returned. + Commit(ctx context.Context) error + + // Rollback rolls back the transaction if this is a real transaction or rolls back to the savepoint if this is a + // pseudo nested transaction. Rollback will return ErrTxClosed if the Tx is already closed, but is otherwise safe to + // call multiple times. Hence, a defer tx.Rollback() is safe even if tx.Commit() will be called first in a non-error + // condition. Any other failure of a real transaction will result in the connection being closed. + Rollback(ctx context.Context) error + + CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) + SendBatch(ctx context.Context, b *Batch) BatchResults + LargeObjects() LargeObjects + + Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) + + Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) + Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) + QueryRow(ctx context.Context, sql string, args ...interface{}) Row + QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) + + // Conn returns the underlying *Conn that on which this transaction is executing. + Conn() *Conn +} + +// dbTx represents a database transaction. +// +// All dbTx methods return ErrTxClosed if Commit or Rollback has already been +// called on the dbTx. +type dbTx struct { + conn *Conn + err error + savepointNum int64 + closed bool +} + +// Begin starts a pseudo nested transaction implemented with a savepoint. +func (tx *dbTx) Begin(ctx context.Context) (Tx, error) { + if tx.closed { + return nil, ErrTxClosed + } + + tx.savepointNum++ + _, err := tx.conn.Exec(ctx, "savepoint sp_"+strconv.FormatInt(tx.savepointNum, 10)) + if err != nil { + return nil, err + } + + return &dbSavepoint{tx: tx, savepointNum: tx.savepointNum}, nil +} + +func (tx *dbTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) { + if tx.closed { + return ErrTxClosed + } + + var savepoint Tx + savepoint, err = tx.Begin(ctx) + if err != nil { + return err + } + defer func() { + rollbackErr := savepoint.Rollback(ctx) + if !(rollbackErr == nil || errors.Is(rollbackErr, ErrTxClosed)) { + err = rollbackErr + } + }() + + fErr := f(savepoint) + if fErr != nil { + _ = savepoint.Rollback(ctx) // ignore rollback error as there is already an error to return + return fErr + } + + return savepoint.Commit(ctx) +} + +// Commit commits the transaction. +func (tx *dbTx) Commit(ctx context.Context) error { + if tx.closed { + return ErrTxClosed + } + + commandTag, err := tx.conn.Exec(ctx, "commit") + tx.closed = true + if err != nil { + if tx.conn.PgConn().TxStatus() != 'I' { + _ = tx.conn.Close(ctx) // already have error to return + } + return err + } + if string(commandTag) == "ROLLBACK" { + return ErrTxCommitRollback + } + + return nil +} + +// Rollback rolls back the transaction. Rollback will return ErrTxClosed if the +// Tx is already closed, but is otherwise safe to call multiple times. Hence, a +// defer tx.Rollback() is safe even if tx.Commit() will be called first in a +// non-error condition. +func (tx *dbTx) Rollback(ctx context.Context) error { + if tx.closed { + return ErrTxClosed + } + + _, err := tx.conn.Exec(ctx, "rollback") + tx.closed = true + if err != nil { + // A rollback failure leaves the connection in an undefined state + tx.conn.die(fmt.Errorf("rollback failed: %w", err)) + return err + } + + return nil +} + +// Exec delegates to the underlying *Conn +func (tx *dbTx) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { + return tx.conn.Exec(ctx, sql, arguments...) +} + +// Prepare delegates to the underlying *Conn +func (tx *dbTx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) { + if tx.closed { + return nil, ErrTxClosed + } + + return tx.conn.Prepare(ctx, name, sql) +} + +// Query delegates to the underlying *Conn +func (tx *dbTx) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { + if tx.closed { + // Because checking for errors can be deferred to the *Rows, build one with the error + err := ErrTxClosed + return &connRows{closed: true, err: err}, err + } + + return tx.conn.Query(ctx, sql, args...) +} + +// QueryRow delegates to the underlying *Conn +func (tx *dbTx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { + rows, _ := tx.Query(ctx, sql, args...) + return (*connRow)(rows.(*connRows)) +} + +// QueryFunc delegates to the underlying *Conn. +func (tx *dbTx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { + if tx.closed { + return nil, ErrTxClosed + } + + return tx.conn.QueryFunc(ctx, sql, args, scans, f) +} + +// CopyFrom delegates to the underlying *Conn +func (tx *dbTx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { + if tx.closed { + return 0, ErrTxClosed + } + + return tx.conn.CopyFrom(ctx, tableName, columnNames, rowSrc) +} + +// SendBatch delegates to the underlying *Conn +func (tx *dbTx) SendBatch(ctx context.Context, b *Batch) BatchResults { + if tx.closed { + return &batchResults{err: ErrTxClosed} + } + + return tx.conn.SendBatch(ctx, b) +} + +// LargeObjects returns a LargeObjects instance for the transaction. +func (tx *dbTx) LargeObjects() LargeObjects { + return LargeObjects{tx: tx} +} + +func (tx *dbTx) Conn() *Conn { + return tx.conn +} + +// dbSavepoint represents a nested transaction implemented by a savepoint. +type dbSavepoint struct { + tx Tx + savepointNum int64 + closed bool +} + +// Begin starts a pseudo nested transaction implemented with a savepoint. +func (sp *dbSavepoint) Begin(ctx context.Context) (Tx, error) { + if sp.closed { + return nil, ErrTxClosed + } + + return sp.tx.Begin(ctx) +} + +func (sp *dbSavepoint) BeginFunc(ctx context.Context, f func(Tx) error) (err error) { + if sp.closed { + return ErrTxClosed + } + + return sp.tx.BeginFunc(ctx, f) +} + +// Commit releases the savepoint essentially committing the pseudo nested transaction. +func (sp *dbSavepoint) Commit(ctx context.Context) error { + if sp.closed { + return ErrTxClosed + } + + _, err := sp.Exec(ctx, "release savepoint sp_"+strconv.FormatInt(sp.savepointNum, 10)) + sp.closed = true + return err +} + +// Rollback rolls back to the savepoint essentially rolling back the pseudo nested transaction. Rollback will return +// ErrTxClosed if the dbSavepoint is already closed, but is otherwise safe to call multiple times. Hence, a defer sp.Rollback() +// is safe even if sp.Commit() will be called first in a non-error condition. +func (sp *dbSavepoint) Rollback(ctx context.Context) error { + if sp.closed { + return ErrTxClosed + } + + _, err := sp.Exec(ctx, "rollback to savepoint sp_"+strconv.FormatInt(sp.savepointNum, 10)) + sp.closed = true + return err +} + +// Exec delegates to the underlying Tx +func (sp *dbSavepoint) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { + if sp.closed { + return nil, ErrTxClosed + } + + return sp.tx.Exec(ctx, sql, arguments...) +} + +// Prepare delegates to the underlying Tx +func (sp *dbSavepoint) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) { + if sp.closed { + return nil, ErrTxClosed + } + + return sp.tx.Prepare(ctx, name, sql) +} + +// Query delegates to the underlying Tx +func (sp *dbSavepoint) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { + if sp.closed { + // Because checking for errors can be deferred to the *Rows, build one with the error + err := ErrTxClosed + return &connRows{closed: true, err: err}, err + } + + return sp.tx.Query(ctx, sql, args...) +} + +// QueryRow delegates to the underlying Tx +func (sp *dbSavepoint) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { + rows, _ := sp.Query(ctx, sql, args...) + return (*connRow)(rows.(*connRows)) +} + +// QueryFunc delegates to the underlying Tx. +func (sp *dbSavepoint) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { + if sp.closed { + return nil, ErrTxClosed + } + + return sp.tx.QueryFunc(ctx, sql, args, scans, f) +} + +// CopyFrom delegates to the underlying *Conn +func (sp *dbSavepoint) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { + if sp.closed { + return 0, ErrTxClosed + } + + return sp.tx.CopyFrom(ctx, tableName, columnNames, rowSrc) +} + +// SendBatch delegates to the underlying *Conn +func (sp *dbSavepoint) SendBatch(ctx context.Context, b *Batch) BatchResults { + if sp.closed { + return &batchResults{err: ErrTxClosed} + } + + return sp.tx.SendBatch(ctx, b) +} + +func (sp *dbSavepoint) LargeObjects() LargeObjects { + return LargeObjects{tx: sp} +} + +func (sp *dbSavepoint) Conn() *Conn { + return sp.tx.Conn() +} diff --git a/vendor/github.com/jackc/pgx/v4/values.go b/vendor/github.com/jackc/pgx/v4/values.go new file mode 100644 index 000000000..1a9454753 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v4/values.go @@ -0,0 +1,280 @@ +package pgx + +import ( + "database/sql/driver" + "fmt" + "math" + "reflect" + "time" + + "github.com/jackc/pgio" + "github.com/jackc/pgtype" +) + +// PostgreSQL format codes +const ( + TextFormatCode = 0 + BinaryFormatCode = 1 +) + +// SerializationError occurs on failure to encode or decode a value +type SerializationError string + +func (e SerializationError) Error() string { + return string(e) +} + +func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, error) { + if arg == nil { + return nil, nil + } + + refVal := reflect.ValueOf(arg) + if refVal.Kind() == reflect.Ptr && refVal.IsNil() { + return nil, nil + } + + switch arg := arg.(type) { + + // https://github.com/jackc/pgx/issues/409 Changed JSON and JSONB to surface + // []byte to database/sql instead of string. But that caused problems with the + // simple protocol because the driver.Valuer case got taken before the + // pgtype.TextEncoder case. And driver.Valuer needed to be first in the usual + // case because of https://github.com/jackc/pgx/issues/339. So instead we + // special case JSON and JSONB. + case *pgtype.JSON: + buf, err := arg.EncodeText(ci, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + return string(buf), nil + case *pgtype.JSONB: + buf, err := arg.EncodeText(ci, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + return string(buf), nil + + case driver.Valuer: + return callValuerValue(arg) + case pgtype.TextEncoder: + buf, err := arg.EncodeText(ci, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + return string(buf), nil + case float32: + return float64(arg), nil + case float64: + return arg, nil + case bool: + return arg, nil + case time.Duration: + return fmt.Sprintf("%d microsecond", int64(arg)/1000), nil + case time.Time: + return arg, nil + case string: + return arg, nil + case []byte: + return arg, nil + case int8: + return int64(arg), nil + case int16: + return int64(arg), nil + case int32: + return int64(arg), nil + case int64: + return arg, nil + case int: + return int64(arg), nil + case uint8: + return int64(arg), nil + case uint16: + return int64(arg), nil + case uint32: + return int64(arg), nil + case uint64: + if arg > math.MaxInt64 { + return nil, fmt.Errorf("arg too big for int64: %v", arg) + } + return int64(arg), nil + case uint: + if uint64(arg) > math.MaxInt64 { + return nil, fmt.Errorf("arg too big for int64: %v", arg) + } + return int64(arg), nil + } + + if dt, found := ci.DataTypeForValue(arg); found { + v := dt.Value + err := v.Set(arg) + if err != nil { + return nil, err + } + buf, err := v.(pgtype.TextEncoder).EncodeText(ci, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + return string(buf), nil + } + + if refVal.Kind() == reflect.Ptr { + arg = refVal.Elem().Interface() + return convertSimpleArgument(ci, arg) + } + + if strippedArg, ok := stripNamedType(&refVal); ok { + return convertSimpleArgument(ci, strippedArg) + } + return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg)) +} + +func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32, arg interface{}) ([]byte, error) { + if arg == nil { + return pgio.AppendInt32(buf, -1), nil + } + + switch arg := arg.(type) { + case pgtype.BinaryEncoder: + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := arg.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if argBuf != nil { + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + return buf, nil + case pgtype.TextEncoder: + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := arg.EncodeText(ci, buf) + if err != nil { + return nil, err + } + if argBuf != nil { + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + return buf, nil + case string: + buf = pgio.AppendInt32(buf, int32(len(arg))) + buf = append(buf, arg...) + return buf, nil + } + + refVal := reflect.ValueOf(arg) + + if refVal.Kind() == reflect.Ptr { + if refVal.IsNil() { + return pgio.AppendInt32(buf, -1), nil + } + arg = refVal.Elem().Interface() + return encodePreparedStatementArgument(ci, buf, oid, arg) + } + + if dt, ok := ci.DataTypeForOID(oid); ok { + value := dt.Value + err := value.Set(arg) + if err != nil { + { + if arg, ok := arg.(driver.Valuer); ok { + v, err := callValuerValue(arg) + if err != nil { + return nil, err + } + return encodePreparedStatementArgument(ci, buf, oid, v) + } + } + + return nil, err + } + + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if argBuf != nil { + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + return buf, nil + } + + if strippedArg, ok := stripNamedType(&refVal); ok { + return encodePreparedStatementArgument(ci, buf, oid, strippedArg) + } + return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) +} + +// chooseParameterFormatCode determines the correct format code for an +// argument to a prepared statement. It defaults to TextFormatCode if no +// determination can be made. +func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid uint32, arg interface{}) int16 { + switch arg := arg.(type) { + case pgtype.ParamFormatPreferrer: + return arg.PreferredParamFormat() + case pgtype.BinaryEncoder: + return BinaryFormatCode + case string, *string, pgtype.TextEncoder: + return TextFormatCode + } + + return ci.ParamFormatCodeForOID(oid) +} + +func stripNamedType(val *reflect.Value) (interface{}, bool) { + switch val.Kind() { + case reflect.Int: + convVal := int(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() + case reflect.Int8: + convVal := int8(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() + case reflect.Int16: + convVal := int16(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() + case reflect.Int32: + convVal := int32(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() + case reflect.Int64: + convVal := int64(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() + case reflect.Uint: + convVal := uint(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() + case reflect.Uint8: + convVal := uint8(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() + case reflect.Uint16: + convVal := uint16(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() + case reflect.Uint32: + convVal := uint32(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() + case reflect.Uint64: + convVal := uint64(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() + case reflect.String: + convVal := val.String() + return convVal, reflect.TypeOf(convVal) != val.Type() + } + + return nil, false +} diff --git a/vendor/github.com/uptrace/bun/.gitignore b/vendor/github.com/uptrace/bun/.gitignore new file mode 100644 index 000000000..6f7763c71 --- /dev/null +++ b/vendor/github.com/uptrace/bun/.gitignore @@ -0,0 +1,3 @@ +*.s3db +*.prof +*.test diff --git a/vendor/github.com/go-pg/pg/v10/.prettierrc b/vendor/github.com/uptrace/bun/.prettierrc.yaml index 8b7f044ad..decea5634 100644 --- a/vendor/github.com/go-pg/pg/v10/.prettierrc +++ b/vendor/github.com/uptrace/bun/.prettierrc.yaml @@ -1,3 +1,5 @@ +trailingComma: all +tabWidth: 2 semi: false singleQuote: true proseWrap: always diff --git a/vendor/github.com/uptrace/bun/CHANGELOG.md b/vendor/github.com/uptrace/bun/CHANGELOG.md new file mode 100644 index 000000000..01bf6ba31 --- /dev/null +++ b/vendor/github.com/uptrace/bun/CHANGELOG.md @@ -0,0 +1,99 @@ +# Changelog + +## v0.4.1 - Aug 18 2021 + +- Fixed migrate package to properly rollback migrations. +- Added `allowzero` tag option that undoes `nullzero` option. + +## v0.4.0 - Aug 11 2021 + +- Changed `WhereGroup` function to accept `*SelectQuery`. +- Fixed query hooks for count queries. + +## v0.3.4 - Jul 19 2021 + +- Renamed `migrate.CreateGo` to `CreateGoMigration`. +- Added `migrate.WithPackageName` to customize the Go package name in generated migrations. +- Renamed `migrate.CreateSQL` to `CreateSQLMigrations` and changed `CreateSQLMigrations` to create + both up and down migration files. + +## v0.3.1 - Jul 12 2021 + +- Renamed `alias` field struct tag to `alt` so it is not confused with column alias. +- Reworked migrate package API. See + [migrate](https://github.com/uptrace/bun/tree/master/example/migrate) example for details. + +## v0.3.0 - Jul 09 2021 + +- Changed migrate package to return structured data instead of logging the progress. See + [migrate](https://github.com/uptrace/bun/tree/master/example/migrate) example for details. + +## v0.2.14 - Jul 01 2021 + +- Added [sqliteshim](https://pkg.go.dev/github.com/uptrace/bun/driver/sqliteshim) by + [Ivan Trubach](https://github.com/tie). +- Added support for MySQL 5.7 in addition to MySQL 8. + +## v0.2.12 - Jun 29 2021 + +- Fixed scanners for net.IP and net.IPNet. + +## v0.2.10 - Jun 29 2021 + +- Fixed pgdriver to format passed query args. + +## v0.2.9 - Jun 27 2021 + +- Added support for prepared statements in pgdriver. + +## v0.2.7 - Jun 26 2021 + +- Added `UpdateQuery.Bulk` helper to generate bulk-update queries. + + Before: + + ```go + models := []Model{ + {42, "hello"}, + {43, "world"}, + } + return db.NewUpdate(). + With("_data", db.NewValues(&models)). + Model(&models). + Table("_data"). + Set("model.str = _data.str"). + Where("model.id = _data.id") + ``` + + Now: + + ```go + db.NewUpdate(). + Model(&models). + Bulk() + ``` + +## v0.2.5 - Jun 25 2021 + +- Changed time.Time to always append zero time as `NULL`. +- Added `db.RunInTx` helper. + +## v0.2.4 - Jun 21 2021 + +- Added SSL support to pgdriver. + +## v0.2.3 - Jun 20 2021 + +- Replaced `ForceDelete(ctx)` with `ForceDelete().Exec(ctx)` for soft deletes. + +## v0.2.1 - Jun 17 2021 + +- Renamed `DBI` to `IConn`. `IConn` is a common interface for `*sql.DB`, `*sql.Conn`, and `*sql.Tx`. +- Added `IDB`. `IDB` is a common interface for `*bun.DB`, `bun.Conn`, and `bun.Tx`. + +## v0.2.0 - Jun 16 2021 + +- Changed [model hooks](https://bun.uptrace.dev/guide/hooks.html#model-hooks). See + [model-hooks](example/model-hooks) example. +- Renamed `has-one` to `belongs-to`. Renamed `belongs-to` to `has-one`. Previously Bun used + incorrect names for these relations. diff --git a/vendor/github.com/go-pg/pg/extra/pgdebug/LICENSE b/vendor/github.com/uptrace/bun/LICENSE index 7751509b8..7ec81810c 100644 --- a/vendor/github.com/go-pg/pg/extra/pgdebug/LICENSE +++ b/vendor/github.com/uptrace/bun/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2013 github.com/go-pg/pg Authors. All rights reserved. +Copyright (c) 2021 Vladimir Mihailenco. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/github.com/uptrace/bun/Makefile b/vendor/github.com/uptrace/bun/Makefile new file mode 100644 index 000000000..54744c617 --- /dev/null +++ b/vendor/github.com/uptrace/bun/Makefile @@ -0,0 +1,21 @@ +ALL_GO_MOD_DIRS := $(shell find . -type f -name 'go.mod' -exec dirname {} \; | sort) + +test: + set -e; for dir in $(ALL_GO_MOD_DIRS); do \ + echo "go test in $${dir}"; \ + (cd "$${dir}" && \ + go test ./... && \ + go vet); \ + done + +go_mod_tidy: + set -e; for dir in $(ALL_GO_MOD_DIRS); do \ + echo "go mod tidy in $${dir}"; \ + (cd "$${dir}" && \ + go get -d ./... && \ + go mod tidy); \ + done + +fmt: + gofmt -w -s ./ + goimports -w -local github.com/uptrace/bun ./ diff --git a/vendor/github.com/uptrace/bun/README.md b/vendor/github.com/uptrace/bun/README.md new file mode 100644 index 000000000..e7cc77a60 --- /dev/null +++ b/vendor/github.com/uptrace/bun/README.md @@ -0,0 +1,267 @@ +<p align="center"> + <a href="https://uptrace.dev/?utm_source=gh-redis&utm_campaign=gh-redis-banner1"> + <img src="https://raw.githubusercontent.com/uptrace/roadmap/master/banner1.png" alt="All-in-one tool to optimize performance and monitor errors & logs"> + </a> +</p> + +# Simple and performant SQL database client + +[](https://github.com/uptrace/bun/actions) +[](https://pkg.go.dev/github.com/uptrace/bun) +[](https://bun.uptrace.dev/) +[](https://discord.gg/rWtp5Aj) + +Main features are: + +- Works with [PostgreSQL](https://bun.uptrace.dev/guide/drivers.html#postgresql), + [MySQL](https://bun.uptrace.dev/guide/drivers.html#mysql), + [SQLite](https://bun.uptrace.dev/guide/drivers.html#sqlite). +- [Selecting](/example/basic/) into a map, struct, slice of maps/structs/vars. +- [Bulk inserts](https://bun.uptrace.dev/guide/queries.html#insert). +- [Bulk updates](https://bun.uptrace.dev/guide/queries.html#update) using common table expressions. +- [Bulk deletes](https://bun.uptrace.dev/guide/queries.html#delete). +- [Fixtures](https://bun.uptrace.dev/guide/fixtures.html). +- [Migrations](https://bun.uptrace.dev/guide/migrations.html). +- [Soft deletes](https://bun.uptrace.dev/guide/soft-deletes.html). + +Resources: + +- [Examples](https://github.com/uptrace/bun/tree/master/example) +- [Documentation](https://bun.uptrace.dev/) +- [Reference](https://pkg.go.dev/github.com/uptrace/bun) +- [Starter kit](https://github.com/go-bun/bun-starter-kit) +- [RealWorld app](https://github.com/go-bun/bun-realworld-app) + +<details> + <summary>github.com/frederikhors/orm-benchmark results</summary> + +``` + 4000 times - Insert + raw_stmt: 0.38s 94280 ns/op 718 B/op 14 allocs/op + raw: 0.39s 96719 ns/op 718 B/op 13 allocs/op + beego_orm: 0.48s 118994 ns/op 2411 B/op 56 allocs/op + bun: 0.57s 142285 ns/op 918 B/op 12 allocs/op + pg: 0.58s 145496 ns/op 1235 B/op 12 allocs/op + gorm: 0.70s 175294 ns/op 6665 B/op 88 allocs/op + xorm: 0.76s 189533 ns/op 3032 B/op 94 allocs/op + + 4000 times - MultiInsert 100 row + raw: 4.59s 1147385 ns/op 135155 B/op 916 allocs/op + raw_stmt: 4.59s 1148137 ns/op 131076 B/op 916 allocs/op + beego_orm: 5.50s 1375637 ns/op 179962 B/op 2747 allocs/op + bun: 6.18s 1544648 ns/op 4265 B/op 214 allocs/op + pg: 7.01s 1753495 ns/op 5039 B/op 114 allocs/op + gorm: 9.52s 2379219 ns/op 293956 B/op 3729 allocs/op + xorm: 11.66s 2915478 ns/op 286140 B/op 7422 allocs/op + + 4000 times - Update + raw_stmt: 0.26s 65781 ns/op 773 B/op 14 allocs/op + raw: 0.31s 77209 ns/op 757 B/op 13 allocs/op + beego_orm: 0.43s 107064 ns/op 1802 B/op 47 allocs/op + bun: 0.56s 139839 ns/op 589 B/op 4 allocs/op + pg: 0.60s 149608 ns/op 896 B/op 11 allocs/op + gorm: 0.74s 185970 ns/op 6604 B/op 81 allocs/op + xorm: 0.81s 203240 ns/op 2994 B/op 119 allocs/op + + 4000 times - Read + raw: 0.33s 81671 ns/op 2081 B/op 49 allocs/op + raw_stmt: 0.34s 85847 ns/op 2112 B/op 50 allocs/op + beego_orm: 0.38s 94777 ns/op 2106 B/op 75 allocs/op + pg: 0.42s 106148 ns/op 1526 B/op 22 allocs/op + bun: 0.43s 106904 ns/op 1319 B/op 18 allocs/op + gorm: 0.65s 162221 ns/op 5240 B/op 108 allocs/op + xorm: 1.13s 281738 ns/op 8326 B/op 237 allocs/op + + 4000 times - MultiRead limit 100 + raw: 1.52s 380351 ns/op 38356 B/op 1037 allocs/op + raw_stmt: 1.54s 385541 ns/op 38388 B/op 1038 allocs/op + pg: 1.86s 465468 ns/op 24045 B/op 631 allocs/op + bun: 2.58s 645354 ns/op 30009 B/op 1122 allocs/op + beego_orm: 2.93s 732028 ns/op 55280 B/op 3077 allocs/op + gorm: 4.97s 1241831 ns/op 71628 B/op 3877 allocs/op + xorm: doesn't work +``` + +</details> + +## Installation + +```go +go get github.com/uptrace/bun +``` + +You also need to install a database/sql driver and the corresponding Bun +[dialect](https://bun.uptrace.dev/guide/drivers.html). + +## Quickstart + +First you need to create a `sql.DB`. Here we are using the +[sqliteshim](https://pkg.go.dev/github.com/uptrace/bun/driver/sqliteshim) driver which choses +between [modernc.org/sqlite](https://modernc.org/sqlite/) and +[mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) depending on your platform. + +```go +import "github.com/uptrace/bun/driver/sqliteshim" + +sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared") +if err != nil { + panic(err) +} +``` + +And then create a `bun.DB` on top of it using the corresponding SQLite dialect: + +```go +import ( + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/sqlitedialect" +) + +db := bun.NewDB(sqldb, sqlitedialect.New()) +``` + +Now you are ready to issue some queries: + +```go +type User struct { + ID int64 + Name string +} + +user := new(User) +err := db.NewSelect(). + Model(user). + Where("name != ?", ""). + OrderExpr("id ASC"). + Limit(1). + Scan(ctx) +``` + +The code above is equivalent to: + +```go +query := "SELECT id, name FROM users AS user WHERE name != '' ORDER BY id ASC LIMIT 1" + +rows, err := sqldb.QueryContext(ctx, query) +if err != nil { + panic(err) +} + +if !rows.Next() { + panic(sql.ErrNoRows) +} + +user := new(User) +if err := db.ScanRow(ctx, rows, user); err != nil { + panic(err) +} + +if err := rows.Err(); err != nil { + panic(err) +} +``` + +## Basic example + +To provide initial data for our [example](/example/basic/), we will use Bun +[fixtures](https://bun.uptrace.dev/guide/fixtures.html): + +```go +import "github.com/uptrace/bun/dbfixture" + +// Register models for the fixture. +db.RegisterModel((*User)(nil), (*Story)(nil)) + +// WithRecreateTables tells Bun to drop existing tables and create new ones. +fixture := dbfixture.New(db, dbfixture.WithRecreateTables()) + +// Load fixture.yaml which contains data for User and Story models. +if err := fixture.Load(ctx, os.DirFS("."), "fixture.yaml"); err != nil { + panic(err) +} +``` + +The `fixture.yaml` looks like this: + +```yaml +- model: User + rows: + - _id: admin + name: admin + emails: ['admin1@admin', 'admin2@admin'] + - _id: root + name: root + emails: ['root1@root', 'root2@root'] + +- model: Story + rows: + - title: Cool story + author_id: '{{ $.User.admin.ID }}' +``` + +To select all users: + +```go +users := make([]User, 0) +if err := db.NewSelect().Model(&users).OrderExpr("id ASC").Scan(ctx); err != nil { + panic(err) +} +``` + +To select a single user by id: + +```go +user1 := new(User) +if err := db.NewSelect().Model(user1).Where("id = ?", 1).Scan(ctx); err != nil { + panic(err) +} +``` + +To select a story and the associated author in a single query: + +```go +story := new(Story) +if err := db.NewSelect(). + Model(story). + Relation("Author"). + Limit(1). + Scan(ctx); err != nil { + panic(err) +} +``` + +To select a user into a map: + +```go +m := make(map[string]interface{}) +if err := db.NewSelect(). + Model((*User)(nil)). + Limit(1). + Scan(ctx, &m); err != nil { + panic(err) +} +``` + +To select all users scanning each column into a separate slice: + +```go +var ids []int64 +var names []string +if err := db.NewSelect(). + ColumnExpr("id, name"). + Model((*User)(nil)). + OrderExpr("id ASC"). + Scan(ctx, &ids, &names); err != nil { + panic(err) +} +``` + +For more details, please consult [docs](https://bun.uptrace.dev/) and check [examples](example). + +## Contributors + +Thanks to all the people who already contributed! + +<a href="https://github.com/uptrace/bun/graphs/contributors"> + <img src="https://contributors-img.web.app/image?repo=uptrace/bun" /> +</a> diff --git a/vendor/github.com/uptrace/bun/RELEASING.md b/vendor/github.com/uptrace/bun/RELEASING.md new file mode 100644 index 000000000..9e50c1063 --- /dev/null +++ b/vendor/github.com/uptrace/bun/RELEASING.md @@ -0,0 +1,21 @@ +# Releasing + +1. Run `release.sh` script which updates versions in go.mod files and pushes a new branch to GitHub: + +```shell +./scripts/release.sh -t v1.0.0 +``` + +2. Open a pull request and wait for the build to finish. + +3. Merge the pull request and run `tag.sh` to create tags for packages: + +```shell +./scripts/tag.sh -t v1.0.0 +``` + +4. Push the tags: + +```shell +git push origin --tags +``` diff --git a/vendor/github.com/uptrace/bun/bun.go b/vendor/github.com/uptrace/bun/bun.go new file mode 100644 index 000000000..92ebe691a --- /dev/null +++ b/vendor/github.com/uptrace/bun/bun.go @@ -0,0 +1,122 @@ +package bun + +import ( + "context" + "fmt" + "reflect" + + "github.com/uptrace/bun/schema" +) + +type ( + Safe = schema.Safe + Ident = schema.Ident +) + +type NullTime = schema.NullTime + +type BaseModel = schema.BaseModel + +type ( + BeforeScanHook = schema.BeforeScanHook + AfterScanHook = schema.AfterScanHook +) + +type BeforeSelectHook interface { + BeforeSelect(ctx context.Context, query *SelectQuery) error +} + +type AfterSelectHook interface { + AfterSelect(ctx context.Context, query *SelectQuery) error +} + +type BeforeInsertHook interface { + BeforeInsert(ctx context.Context, query *InsertQuery) error +} + +type AfterInsertHook interface { + AfterInsert(ctx context.Context, query *InsertQuery) error +} + +type BeforeUpdateHook interface { + BeforeUpdate(ctx context.Context, query *UpdateQuery) error +} + +type AfterUpdateHook interface { + AfterUpdate(ctx context.Context, query *UpdateQuery) error +} + +type BeforeDeleteHook interface { + BeforeDelete(ctx context.Context, query *DeleteQuery) error +} + +type AfterDeleteHook interface { + AfterDelete(ctx context.Context, query *DeleteQuery) error +} + +type BeforeCreateTableHook interface { + BeforeCreateTable(ctx context.Context, query *CreateTableQuery) error +} + +type AfterCreateTableHook interface { + AfterCreateTable(ctx context.Context, query *CreateTableQuery) error +} + +type BeforeDropTableHook interface { + BeforeDropTable(ctx context.Context, query *DropTableQuery) error +} + +type AfterDropTableHook interface { + AfterDropTable(ctx context.Context, query *DropTableQuery) error +} + +//------------------------------------------------------------------------------ + +type InValues struct { + slice reflect.Value + err error +} + +var _ schema.QueryAppender = InValues{} + +func In(slice interface{}) InValues { + v := reflect.ValueOf(slice) + if v.Kind() != reflect.Slice { + return InValues{ + err: fmt.Errorf("bun: In(non-slice %T)", slice), + } + } + return InValues{ + slice: v, + } +} + +func (in InValues) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if in.err != nil { + return nil, in.err + } + return appendIn(fmter, b, in.slice), nil +} + +func appendIn(fmter schema.Formatter, b []byte, slice reflect.Value) []byte { + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + b = append(b, ", "...) + } + + elem := slice.Index(i) + if elem.Kind() == reflect.Interface { + elem = elem.Elem() + } + + if elem.Kind() == reflect.Slice { + b = append(b, '(') + b = appendIn(fmter, b, elem) + b = append(b, ')') + } else { + b = fmter.AppendValue(b, elem) + } + } + return b +} diff --git a/vendor/github.com/uptrace/bun/db.go b/vendor/github.com/uptrace/bun/db.go new file mode 100644 index 000000000..d08adefb5 --- /dev/null +++ b/vendor/github.com/uptrace/bun/db.go @@ -0,0 +1,502 @@ +package bun + +import ( + "context" + "database/sql" + "errors" + "fmt" + "reflect" + "strings" + "sync/atomic" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +const ( + discardUnknownColumns internal.Flag = 1 << iota +) + +type DBStats struct { + Queries uint64 + Errors uint64 +} + +type DBOption func(db *DB) + +func WithDiscardUnknownColumns() DBOption { + return func(db *DB) { + db.flags = db.flags.Set(discardUnknownColumns) + } +} + +type DB struct { + *sql.DB + dialect schema.Dialect + features feature.Feature + + queryHooks []QueryHook + + fmter schema.Formatter + flags internal.Flag + + stats DBStats +} + +func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB { + dialect.Init(sqldb) + + db := &DB{ + DB: sqldb, + dialect: dialect, + features: dialect.Features(), + fmter: schema.NewFormatter(dialect), + } + + for _, opt := range opts { + opt(db) + } + + return db +} + +func (db *DB) String() string { + var b strings.Builder + b.WriteString("DB<dialect=") + b.WriteString(db.dialect.Name().String()) + b.WriteString(">") + return b.String() +} + +func (db *DB) DBStats() DBStats { + return DBStats{ + Queries: atomic.LoadUint64(&db.stats.Queries), + Errors: atomic.LoadUint64(&db.stats.Errors), + } +} + +func (db *DB) NewValues(model interface{}) *ValuesQuery { + return NewValuesQuery(db, model) +} + +func (db *DB) NewSelect() *SelectQuery { + return NewSelectQuery(db) +} + +func (db *DB) NewInsert() *InsertQuery { + return NewInsertQuery(db) +} + +func (db *DB) NewUpdate() *UpdateQuery { + return NewUpdateQuery(db) +} + +func (db *DB) NewDelete() *DeleteQuery { + return NewDeleteQuery(db) +} + +func (db *DB) NewCreateTable() *CreateTableQuery { + return NewCreateTableQuery(db) +} + +func (db *DB) NewDropTable() *DropTableQuery { + return NewDropTableQuery(db) +} + +func (db *DB) NewCreateIndex() *CreateIndexQuery { + return NewCreateIndexQuery(db) +} + +func (db *DB) NewDropIndex() *DropIndexQuery { + return NewDropIndexQuery(db) +} + +func (db *DB) NewTruncateTable() *TruncateTableQuery { + return NewTruncateTableQuery(db) +} + +func (db *DB) NewAddColumn() *AddColumnQuery { + return NewAddColumnQuery(db) +} + +func (db *DB) NewDropColumn() *DropColumnQuery { + return NewDropColumnQuery(db) +} + +func (db *DB) ResetModel(ctx context.Context, models ...interface{}) error { + for _, model := range models { + if _, err := db.NewDropTable().Model(model).IfExists().Exec(ctx); err != nil { + return err + } + if _, err := db.NewCreateTable().Model(model).Exec(ctx); err != nil { + return err + } + } + return nil +} + +func (db *DB) Dialect() schema.Dialect { + return db.dialect +} + +func (db *DB) ScanRows(ctx context.Context, rows *sql.Rows, dest ...interface{}) error { + model, err := newModel(db, dest) + if err != nil { + return err + } + + _, err = model.ScanRows(ctx, rows) + return err +} + +func (db *DB) ScanRow(ctx context.Context, rows *sql.Rows, dest ...interface{}) error { + model, err := newModel(db, dest) + if err != nil { + return err + } + + rs, ok := model.(rowScanner) + if !ok { + return fmt.Errorf("bun: %T does not support ScanRow", model) + } + + return rs.ScanRow(ctx, rows) +} + +func (db *DB) AddQueryHook(hook QueryHook) { + db.queryHooks = append(db.queryHooks, hook) +} + +func (db *DB) Table(typ reflect.Type) *schema.Table { + return db.dialect.Tables().Get(typ) +} + +func (db *DB) RegisterModel(models ...interface{}) { + db.dialect.Tables().Register(models...) +} + +func (db *DB) clone() *DB { + clone := *db + + l := len(clone.queryHooks) + clone.queryHooks = clone.queryHooks[:l:l] + + return &clone +} + +func (db *DB) WithNamedArg(name string, value interface{}) *DB { + clone := db.clone() + clone.fmter = clone.fmter.WithNamedArg(name, value) + return clone +} + +func (db *DB) Formatter() schema.Formatter { + return db.fmter +} + +//------------------------------------------------------------------------------ + +func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) { + return db.ExecContext(context.Background(), query, args...) +} + +func (db *DB) ExecContext( + ctx context.Context, query string, args ...interface{}, +) (sql.Result, error) { + ctx, event := db.beforeQuery(ctx, nil, query, args) + res, err := db.DB.ExecContext(ctx, db.format(query, args)) + db.afterQuery(ctx, event, res, err) + return res, err +} + +func (db *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { + return db.QueryContext(context.Background(), query, args...) +} + +func (db *DB) QueryContext( + ctx context.Context, query string, args ...interface{}, +) (*sql.Rows, error) { + ctx, event := db.beforeQuery(ctx, nil, query, args) + rows, err := db.DB.QueryContext(ctx, db.format(query, args)) + db.afterQuery(ctx, event, nil, err) + return rows, err +} + +func (db *DB) QueryRow(query string, args ...interface{}) *sql.Row { + return db.QueryRowContext(context.Background(), query, args...) +} + +func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + ctx, event := db.beforeQuery(ctx, nil, query, args) + row := db.DB.QueryRowContext(ctx, db.format(query, args)) + db.afterQuery(ctx, event, nil, row.Err()) + return row +} + +func (db *DB) format(query string, args []interface{}) string { + return db.fmter.FormatQuery(query, args...) +} + +//------------------------------------------------------------------------------ + +type Conn struct { + db *DB + *sql.Conn +} + +func (db *DB) Conn(ctx context.Context) (Conn, error) { + conn, err := db.DB.Conn(ctx) + if err != nil { + return Conn{}, err + } + return Conn{ + db: db, + Conn: conn, + }, nil +} + +func (c Conn) ExecContext( + ctx context.Context, query string, args ...interface{}, +) (sql.Result, error) { + ctx, event := c.db.beforeQuery(ctx, nil, query, args) + res, err := c.Conn.ExecContext(ctx, c.db.format(query, args)) + c.db.afterQuery(ctx, event, res, err) + return res, err +} + +func (c Conn) QueryContext( + ctx context.Context, query string, args ...interface{}, +) (*sql.Rows, error) { + ctx, event := c.db.beforeQuery(ctx, nil, query, args) + rows, err := c.Conn.QueryContext(ctx, c.db.format(query, args)) + c.db.afterQuery(ctx, event, nil, err) + return rows, err +} + +func (c Conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + ctx, event := c.db.beforeQuery(ctx, nil, query, args) + row := c.Conn.QueryRowContext(ctx, c.db.format(query, args)) + c.db.afterQuery(ctx, event, nil, row.Err()) + return row +} + +func (c Conn) NewValues(model interface{}) *ValuesQuery { + return NewValuesQuery(c.db, model).Conn(c) +} + +func (c Conn) NewSelect() *SelectQuery { + return NewSelectQuery(c.db).Conn(c) +} + +func (c Conn) NewInsert() *InsertQuery { + return NewInsertQuery(c.db).Conn(c) +} + +func (c Conn) NewUpdate() *UpdateQuery { + return NewUpdateQuery(c.db).Conn(c) +} + +func (c Conn) NewDelete() *DeleteQuery { + return NewDeleteQuery(c.db).Conn(c) +} + +func (c Conn) NewCreateTable() *CreateTableQuery { + return NewCreateTableQuery(c.db).Conn(c) +} + +func (c Conn) NewDropTable() *DropTableQuery { + return NewDropTableQuery(c.db).Conn(c) +} + +func (c Conn) NewCreateIndex() *CreateIndexQuery { + return NewCreateIndexQuery(c.db).Conn(c) +} + +func (c Conn) NewDropIndex() *DropIndexQuery { + return NewDropIndexQuery(c.db).Conn(c) +} + +func (c Conn) NewTruncateTable() *TruncateTableQuery { + return NewTruncateTableQuery(c.db).Conn(c) +} + +func (c Conn) NewAddColumn() *AddColumnQuery { + return NewAddColumnQuery(c.db).Conn(c) +} + +func (c Conn) NewDropColumn() *DropColumnQuery { + return NewDropColumnQuery(c.db).Conn(c) +} + +//------------------------------------------------------------------------------ + +type Stmt struct { + *sql.Stmt +} + +func (db *DB) Prepare(query string) (Stmt, error) { + return db.PrepareContext(context.Background(), query) +} + +func (db *DB) PrepareContext(ctx context.Context, query string) (Stmt, error) { + stmt, err := db.DB.PrepareContext(ctx, query) + if err != nil { + return Stmt{}, err + } + return Stmt{Stmt: stmt}, nil +} + +//------------------------------------------------------------------------------ + +type Tx struct { + db *DB + *sql.Tx +} + +// RunInTx runs the function in a transaction. If the function returns an error, +// the transaction is rolled back. Otherwise, the transaction is committed. +func (db *DB) RunInTx( + ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx Tx) error, +) error { + tx, err := db.BeginTx(ctx, opts) + if err != nil { + return err + } + defer tx.Rollback() //nolint:errcheck + + if err := fn(ctx, tx); err != nil { + return err + } + return tx.Commit() +} + +func (db *DB) Begin() (Tx, error) { + return db.BeginTx(context.Background(), nil) +} + +func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { + tx, err := db.DB.BeginTx(ctx, opts) + if err != nil { + return Tx{}, err + } + return Tx{ + db: db, + Tx: tx, + }, nil +} + +func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) { + return tx.ExecContext(context.TODO(), query, args...) +} + +func (tx Tx) ExecContext( + ctx context.Context, query string, args ...interface{}, +) (sql.Result, error) { + ctx, event := tx.db.beforeQuery(ctx, nil, query, args) + res, err := tx.Tx.ExecContext(ctx, tx.db.format(query, args)) + tx.db.afterQuery(ctx, event, res, err) + return res, err +} + +func (tx Tx) Query(query string, args ...interface{}) (*sql.Rows, error) { + return tx.QueryContext(context.TODO(), query, args...) +} + +func (tx Tx) QueryContext( + ctx context.Context, query string, args ...interface{}, +) (*sql.Rows, error) { + ctx, event := tx.db.beforeQuery(ctx, nil, query, args) + rows, err := tx.Tx.QueryContext(ctx, tx.db.format(query, args)) + tx.db.afterQuery(ctx, event, nil, err) + return rows, err +} + +func (tx Tx) QueryRow(query string, args ...interface{}) *sql.Row { + return tx.QueryRowContext(context.TODO(), query, args...) +} + +func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + ctx, event := tx.db.beforeQuery(ctx, nil, query, args) + row := tx.Tx.QueryRowContext(ctx, tx.db.format(query, args)) + tx.db.afterQuery(ctx, event, nil, row.Err()) + return row +} + +//------------------------------------------------------------------------------ + +func (tx Tx) NewValues(model interface{}) *ValuesQuery { + return NewValuesQuery(tx.db, model).Conn(tx) +} + +func (tx Tx) NewSelect() *SelectQuery { + return NewSelectQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewInsert() *InsertQuery { + return NewInsertQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewUpdate() *UpdateQuery { + return NewUpdateQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewDelete() *DeleteQuery { + return NewDeleteQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewCreateTable() *CreateTableQuery { + return NewCreateTableQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewDropTable() *DropTableQuery { + return NewDropTableQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewCreateIndex() *CreateIndexQuery { + return NewCreateIndexQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewDropIndex() *DropIndexQuery { + return NewDropIndexQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewTruncateTable() *TruncateTableQuery { + return NewTruncateTableQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewAddColumn() *AddColumnQuery { + return NewAddColumnQuery(tx.db).Conn(tx) +} + +func (tx Tx) NewDropColumn() *DropColumnQuery { + return NewDropColumnQuery(tx.db).Conn(tx) +} + +//------------------------------------------------------------------------------0 + +func (db *DB) makeQueryBytes() []byte { + // TODO: make this configurable? + return make([]byte, 0, 4096) +} + +//------------------------------------------------------------------------------ + +type result struct { + r sql.Result + n int +} + +func (r result) RowsAffected() (int64, error) { + if r.r != nil { + return r.r.RowsAffected() + } + return int64(r.n), nil +} + +func (r result) LastInsertId() (int64, error) { + if r.r != nil { + return r.r.LastInsertId() + } + return 0, errors.New("LastInsertId is not available") +} diff --git a/vendor/github.com/uptrace/bun/dialect/append.go b/vendor/github.com/uptrace/bun/dialect/append.go new file mode 100644 index 000000000..7040c5155 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/append.go @@ -0,0 +1,178 @@ +package dialect + +import ( + "encoding/hex" + "math" + "strconv" + "time" + "unicode/utf8" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/internal/parser" +) + +func AppendError(b []byte, err error) []byte { + b = append(b, "?!("...) + b = append(b, err.Error()...) + b = append(b, ')') + return b +} + +func AppendNull(b []byte) []byte { + return append(b, "NULL"...) +} + +func AppendBool(b []byte, v bool) []byte { + if v { + return append(b, "TRUE"...) + } + return append(b, "FALSE"...) +} + +func AppendFloat32(b []byte, v float32) []byte { + return appendFloat(b, float64(v), 32) +} + +func AppendFloat64(b []byte, v float64) []byte { + return appendFloat(b, v, 64) +} + +func appendFloat(b []byte, v float64, bitSize int) []byte { + switch { + case math.IsNaN(v): + return append(b, "'NaN'"...) + case math.IsInf(v, 1): + return append(b, "'Infinity'"...) + case math.IsInf(v, -1): + return append(b, "'-Infinity'"...) + default: + return strconv.AppendFloat(b, v, 'f', -1, bitSize) + } +} + +func AppendString(b []byte, s string) []byte { + b = append(b, '\'') + for _, r := range s { + if r == '\000' { + continue + } + + if r == '\'' { + b = append(b, '\'', '\'') + continue + } + + if r < utf8.RuneSelf { + b = append(b, byte(r)) + continue + } + + l := len(b) + if cap(b)-l < utf8.UTFMax { + b = append(b, make([]byte, utf8.UTFMax)...) + } + n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r) + b = b[:l+n] + } + b = append(b, '\'') + return b +} + +func AppendBytes(b []byte, bytes []byte) []byte { + if bytes == nil { + return AppendNull(b) + } + + b = append(b, `'\x`...) + + s := len(b) + b = append(b, make([]byte, hex.EncodedLen(len(bytes)))...) + hex.Encode(b[s:], bytes) + + b = append(b, '\'') + + return b +} + +func AppendTime(b []byte, tm time.Time) []byte { + if tm.IsZero() { + return AppendNull(b) + } + b = append(b, '\'') + b = tm.UTC().AppendFormat(b, "2006-01-02 15:04:05.999999-07:00") + b = append(b, '\'') + return b +} + +func AppendJSON(b, jsonb []byte) []byte { + b = append(b, '\'') + + p := parser.New(jsonb) + for p.Valid() { + c := p.Read() + switch c { + case '"': + b = append(b, '"') + case '\'': + b = append(b, "''"...) + case '\000': + continue + case '\\': + if p.SkipBytes([]byte("u0000")) { + b = append(b, `\\u0000`...) + } else { + b = append(b, '\\') + if p.Valid() { + b = append(b, p.Read()) + } + } + default: + b = append(b, c) + } + } + + b = append(b, '\'') + + return b +} + +//------------------------------------------------------------------------------ + +func AppendIdent(b []byte, field string, quote byte) []byte { + return appendIdent(b, internal.Bytes(field), quote) +} + +func appendIdent(b, src []byte, quote byte) []byte { + var quoted bool +loop: + for _, c := range src { + switch c { + case '*': + if !quoted { + b = append(b, '*') + continue loop + } + case '.': + if quoted { + b = append(b, quote) + quoted = false + } + b = append(b, '.') + continue loop + } + + if !quoted { + b = append(b, quote) + quoted = true + } + if c == quote { + b = append(b, quote, quote) + } else { + b = append(b, c) + } + } + if quoted { + b = append(b, quote) + } + return b +} diff --git a/vendor/github.com/uptrace/bun/dialect/dialect.go b/vendor/github.com/uptrace/bun/dialect/dialect.go new file mode 100644 index 000000000..9ff8b2461 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/dialect.go @@ -0,0 +1,26 @@ +package dialect + +type Name int + +func (n Name) String() string { + switch n { + case PG: + return "pg" + case SQLite: + return "sqlite" + case MySQL5: + return "mysql5" + case MySQL8: + return "mysql8" + default: + return "invalid" + } +} + +const ( + Invalid Name = iota + PG + SQLite + MySQL5 + MySQL8 +) diff --git a/vendor/github.com/uptrace/bun/dialect/feature/feature.go b/vendor/github.com/uptrace/bun/dialect/feature/feature.go new file mode 100644 index 000000000..ff8f1d625 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/feature/feature.go @@ -0,0 +1,22 @@ +package feature + +import "github.com/uptrace/bun/internal" + +type Feature = internal.Flag + +const DefaultFeatures = Returning | TableCascade + +const ( + Returning Feature = 1 << iota + DefaultPlaceholder + DoubleColonCast + ValuesRow + UpdateMultiTable + InsertTableAlias + DeleteTableAlias + AutoIncrement + TableCascade + TableIdentity + TableTruncate + OnDuplicateKey +) diff --git a/vendor/github.com/go-pg/pg/v10/LICENSE b/vendor/github.com/uptrace/bun/dialect/pgdialect/LICENSE index 7751509b8..7ec81810c 100644 --- a/vendor/github.com/go-pg/pg/v10/LICENSE +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2013 github.com/go-pg/pg Authors. All rights reserved. +Copyright (c) 2021 Vladimir Mihailenco. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/append.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/append.go new file mode 100644 index 000000000..475621197 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/append.go @@ -0,0 +1,303 @@ +package pgdialect + +import ( + "database/sql/driver" + "fmt" + "reflect" + "strconv" + "time" + "unicode/utf8" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/schema" +) + +var ( + driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + + stringType = reflect.TypeOf((*string)(nil)).Elem() + sliceStringType = reflect.TypeOf([]string(nil)) + + intType = reflect.TypeOf((*int)(nil)).Elem() + sliceIntType = reflect.TypeOf([]int(nil)) + + int64Type = reflect.TypeOf((*int64)(nil)).Elem() + sliceInt64Type = reflect.TypeOf([]int64(nil)) + + float64Type = reflect.TypeOf((*float64)(nil)).Elem() + sliceFloat64Type = reflect.TypeOf([]float64(nil)) +) + +func customAppender(typ reflect.Type) schema.AppenderFunc { + switch typ.Kind() { + case reflect.Uint32: + return appendUint32ValueAsInt + case reflect.Uint, reflect.Uint64: + return appendUint64ValueAsInt + } + return nil +} + +func appendUint32ValueAsInt(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + return strconv.AppendInt(b, int64(int32(v.Uint())), 10) +} + +func appendUint64ValueAsInt(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + return strconv.AppendInt(b, int64(v.Uint()), 10) +} + +//------------------------------------------------------------------------------ + +func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte { + switch v := v.(type) { + case int64: + return strconv.AppendInt(b, v, 10) + case float64: + return dialect.AppendFloat64(b, v) + case bool: + return dialect.AppendBool(b, v) + case []byte: + return dialect.AppendBytes(b, v) + case string: + return arrayAppendString(b, v) + case time.Time: + return dialect.AppendTime(b, v) + default: + err := fmt.Errorf("pgdialect: can't append %T", v) + return dialect.AppendError(b, err) + } +} + +func arrayElemAppender(typ reflect.Type) schema.AppenderFunc { + if typ.Kind() == reflect.String { + return arrayAppendStringValue + } + if typ.Implements(driverValuerType) { + return arrayAppendDriverValue + } + return schema.Appender(typ, customAppender) +} + +func arrayAppendStringValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + return arrayAppendString(b, v.String()) +} + +func arrayAppendDriverValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + iface, err := v.Interface().(driver.Valuer).Value() + if err != nil { + return dialect.AppendError(b, err) + } + return arrayAppend(fmter, b, iface) +} + +//------------------------------------------------------------------------------ + +func arrayAppender(typ reflect.Type) schema.AppenderFunc { + kind := typ.Kind() + if kind == reflect.Ptr { + typ = typ.Elem() + kind = typ.Kind() + } + + switch kind { + case reflect.Slice, reflect.Array: + // ok: + default: + return nil + } + + elemType := typ.Elem() + + if kind == reflect.Slice { + switch elemType { + case stringType: + return appendStringSliceValue + case intType: + return appendIntSliceValue + case int64Type: + return appendInt64SliceValue + case float64Type: + return appendFloat64SliceValue + } + } + + appendElem := arrayElemAppender(elemType) + if appendElem == nil { + panic(fmt.Errorf("pgdialect: %s is not supported", typ)) + } + + return func(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + kind := v.Kind() + switch kind { + case reflect.Ptr, reflect.Slice: + if v.IsNil() { + return dialect.AppendNull(b) + } + } + + if kind == reflect.Ptr { + v = v.Elem() + } + + b = append(b, '\'') + + b = append(b, '{') + for i := 0; i < v.Len(); i++ { + elem := v.Index(i) + b = appendElem(fmter, b, elem) + b = append(b, ',') + } + if v.Len() > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b + } +} + +func appendStringSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + ss := v.Convert(sliceStringType).Interface().([]string) + return appendStringSlice(b, ss) +} + +func appendStringSlice(b []byte, ss []string) []byte { + if ss == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, s := range ss { + b = arrayAppendString(b, s) + b = append(b, ',') + } + if len(ss) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +func appendIntSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + ints := v.Convert(sliceIntType).Interface().([]int) + return appendIntSlice(b, ints) +} + +func appendIntSlice(b []byte, ints []int) []byte { + if ints == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, n := range ints { + b = strconv.AppendInt(b, int64(n), 10) + b = append(b, ',') + } + if len(ints) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +func appendInt64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + ints := v.Convert(sliceInt64Type).Interface().([]int64) + return appendInt64Slice(b, ints) +} + +func appendInt64Slice(b []byte, ints []int64) []byte { + if ints == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, n := range ints { + b = strconv.AppendInt(b, n, 10) + b = append(b, ',') + } + if len(ints) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +func appendFloat64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + floats := v.Convert(sliceFloat64Type).Interface().([]float64) + return appendFloat64Slice(b, floats) +} + +func appendFloat64Slice(b []byte, floats []float64) []byte { + if floats == nil { + return dialect.AppendNull(b) + } + + b = append(b, '\'') + + b = append(b, '{') + for _, n := range floats { + b = dialect.AppendFloat64(b, n) + b = append(b, ',') + } + if len(floats) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + + b = append(b, '\'') + + return b +} + +//------------------------------------------------------------------------------ + +func arrayAppendString(b []byte, s string) []byte { + b = append(b, '"') + for _, r := range s { + switch r { + case 0: + // ignore + case '\'': + b = append(b, "'''"...) + case '"': + b = append(b, '\\', '"') + case '\\': + b = append(b, '\\', '\\') + default: + if r < utf8.RuneSelf { + b = append(b, byte(r)) + break + } + l := len(b) + if cap(b)-l < utf8.UTFMax { + b = append(b, make([]byte, utf8.UTFMax)...) + } + n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r) + b = b[:l+n] + } + } + b = append(b, '"') + return b +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/array.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/array.go new file mode 100644 index 000000000..57f5a4384 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/array.go @@ -0,0 +1,65 @@ +package pgdialect + +import ( + "database/sql" + "fmt" + "reflect" + + "github.com/uptrace/bun/schema" +) + +type ArrayValue struct { + v reflect.Value + + append schema.AppenderFunc + scan schema.ScannerFunc +} + +// Array accepts a slice and returns a wrapper for working with PostgreSQL +// array data type. +// +// For struct fields you can use array tag: +// +// Emails []string `bun:",array"` +func Array(vi interface{}) *ArrayValue { + v := reflect.ValueOf(vi) + if !v.IsValid() { + panic(fmt.Errorf("bun: Array(nil)")) + } + + return &ArrayValue{ + v: v, + + append: arrayAppender(v.Type()), + scan: arrayScanner(v.Type()), + } +} + +var ( + _ schema.QueryAppender = (*ArrayValue)(nil) + _ sql.Scanner = (*ArrayValue)(nil) +) + +func (a *ArrayValue) AppendQuery(fmter schema.Formatter, b []byte) ([]byte, error) { + if a.append == nil { + panic(fmt.Errorf("bun: Array(unsupported %s)", a.v.Type())) + } + return a.append(fmter, b, a.v), nil +} + +func (a *ArrayValue) Scan(src interface{}) error { + if a.scan == nil { + return fmt.Errorf("bun: Array(unsupported %s)", a.v.Type()) + } + if a.v.Kind() != reflect.Ptr { + return fmt.Errorf("bun: Array(non-pointer %s)", a.v.Type()) + } + return a.scan(a.v, src) +} + +func (a *ArrayValue) Value() interface{} { + if a.v.IsValid() { + return a.v.Interface() + } + return nil +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/array_parser.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/array_parser.go new file mode 100644 index 000000000..1c927fca0 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/array_parser.go @@ -0,0 +1,146 @@ +package pgdialect + +import ( + "bytes" + "fmt" + "io" +) + +type arrayParser struct { + b []byte + i int + + buf []byte + err error +} + +func newArrayParser(b []byte) *arrayParser { + p := &arrayParser{ + b: b, + i: 1, + } + if len(b) < 2 || b[0] != '{' || b[len(b)-1] != '}' { + p.err = fmt.Errorf("bun: can't parse array: %q", b) + } + return p +} + +func (p *arrayParser) NextElem() ([]byte, error) { + if p.err != nil { + return nil, p.err + } + + c, err := p.readByte() + if err != nil { + return nil, err + } + + switch c { + case '}': + return nil, io.EOF + case '"': + b, err := p.readSubstring() + if err != nil { + return nil, err + } + + if p.peek() == ',' { + p.skipNext() + } + + return b, nil + default: + b := p.readSimple() + if bytes.Equal(b, []byte("NULL")) { + b = nil + } + + if p.peek() == ',' { + p.skipNext() + } + + return b, nil + } +} + +func (p *arrayParser) readSimple() []byte { + p.unreadByte() + + if i := bytes.IndexByte(p.b[p.i:], ','); i >= 0 { + b := p.b[p.i : p.i+i] + p.i += i + return b + } + + b := p.b[p.i : len(p.b)-1] + p.i = len(p.b) - 1 + return b +} + +func (p *arrayParser) readSubstring() ([]byte, error) { + c, err := p.readByte() + if err != nil { + return nil, err + } + + p.buf = p.buf[:0] + for { + if c == '"' { + break + } + + next, err := p.readByte() + if err != nil { + return nil, err + } + + if c == '\\' { + switch next { + case '\\', '"': + p.buf = append(p.buf, next) + + c, err = p.readByte() + if err != nil { + return nil, err + } + default: + p.buf = append(p.buf, '\\') + c = next + } + continue + } + + p.buf = append(p.buf, c) + c = next + } + + return p.buf, nil +} + +func (p *arrayParser) valid() bool { + return p.i < len(p.b) +} + +func (p *arrayParser) readByte() (byte, error) { + if p.valid() { + c := p.b[p.i] + p.i++ + return c, nil + } + return 0, io.EOF +} + +func (p *arrayParser) unreadByte() { + p.i-- +} + +func (p *arrayParser) peek() byte { + if p.valid() { + return p.b[p.i] + } + return 0 +} + +func (p *arrayParser) skipNext() { + p.i++ +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/array_scan.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/array_scan.go new file mode 100644 index 000000000..33d31f325 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/array_scan.go @@ -0,0 +1,302 @@ +package pgdialect + +import ( + "fmt" + "io" + "reflect" + "strconv" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +func arrayScanner(typ reflect.Type) schema.ScannerFunc { + kind := typ.Kind() + if kind == reflect.Ptr { + typ = typ.Elem() + kind = typ.Kind() + } + + switch kind { + case reflect.Slice, reflect.Array: + // ok: + default: + return nil + } + + elemType := typ.Elem() + + if kind == reflect.Slice { + switch elemType { + case stringType: + return scanStringSliceValue + case intType: + return scanIntSliceValue + case int64Type: + return scanInt64SliceValue + case float64Type: + return scanFloat64SliceValue + } + } + + scanElem := schema.Scanner(elemType) + return func(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + kind := dest.Kind() + + if src == nil { + if kind != reflect.Slice || !dest.IsNil() { + dest.Set(reflect.Zero(dest.Type())) + } + return nil + } + + if kind == reflect.Slice { + if dest.IsNil() { + dest.Set(reflect.MakeSlice(dest.Type(), 0, 0)) + } else if dest.Len() > 0 { + dest.Set(dest.Slice(0, 0)) + } + } + + b, err := toBytes(src) + if err != nil { + return err + } + + p := newArrayParser(b) + nextValue := internal.MakeSliceNextElemFunc(dest) + for { + elem, err := p.NextElem() + if err != nil { + if err == io.EOF { + break + } + return err + } + + elemValue := nextValue() + if err := scanElem(elemValue, elem); err != nil { + return err + } + } + + return nil + } +} + +func scanStringSliceValue(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + slice, err := decodeStringSlice(src) + if err != nil { + return err + } + + dest.Set(reflect.ValueOf(slice)) + return nil +} + +func decodeStringSlice(src interface{}) ([]string, error) { + if src == nil { + return nil, nil + } + + b, err := toBytes(src) + if err != nil { + return nil, err + } + + slice := make([]string, 0) + + p := newArrayParser(b) + for { + elem, err := p.NextElem() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + slice = append(slice, string(elem)) + } + + return slice, nil +} + +func scanIntSliceValue(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + slice, err := decodeIntSlice(src) + if err != nil { + return err + } + + dest.Set(reflect.ValueOf(slice)) + return nil +} + +func decodeIntSlice(src interface{}) ([]int, error) { + if src == nil { + return nil, nil + } + + b, err := toBytes(src) + if err != nil { + return nil, err + } + + slice := make([]int, 0) + + p := newArrayParser(b) + for { + elem, err := p.NextElem() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + + if elem == nil { + slice = append(slice, 0) + continue + } + + n, err := strconv.Atoi(bytesToString(elem)) + if err != nil { + return nil, err + } + + slice = append(slice, n) + } + + return slice, nil +} + +func scanInt64SliceValue(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + slice, err := decodeInt64Slice(src) + if err != nil { + return err + } + + dest.Set(reflect.ValueOf(slice)) + return nil +} + +func decodeInt64Slice(src interface{}) ([]int64, error) { + if src == nil { + return nil, nil + } + + b, err := toBytes(src) + if err != nil { + return nil, err + } + + slice := make([]int64, 0) + + p := newArrayParser(b) + for { + elem, err := p.NextElem() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + + if elem == nil { + slice = append(slice, 0) + continue + } + + n, err := strconv.ParseInt(bytesToString(elem), 10, 64) + if err != nil { + return nil, err + } + + slice = append(slice, n) + } + + return slice, nil +} + +func scanFloat64SliceValue(dest reflect.Value, src interface{}) error { + dest = reflect.Indirect(dest) + if !dest.CanSet() { + return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type()) + } + + slice, err := scanFloat64Slice(src) + if err != nil { + return err + } + + dest.Set(reflect.ValueOf(slice)) + return nil +} + +func scanFloat64Slice(src interface{}) ([]float64, error) { + if src == -1 { + return nil, nil + } + + b, err := toBytes(src) + if err != nil { + return nil, err + } + + slice := make([]float64, 0) + + p := newArrayParser(b) + for { + elem, err := p.NextElem() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + + if elem == nil { + slice = append(slice, 0) + continue + } + + n, err := strconv.ParseFloat(bytesToString(elem), 64) + if err != nil { + return nil, err + } + + slice = append(slice, n) + } + + return slice, nil +} + +func toBytes(src interface{}) ([]byte, error) { + switch src := src.(type) { + case string: + return stringToBytes(src), nil + case []byte: + return src, nil + default: + return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src) + } +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/dialect.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/dialect.go new file mode 100644 index 000000000..fb210751b --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/dialect.go @@ -0,0 +1,150 @@ +package pgdialect + +import ( + "database/sql" + "reflect" + "strconv" + "sync" + "time" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/dialect/sqltype" + "github.com/uptrace/bun/schema" +) + +type Dialect struct { + tables *schema.Tables + features feature.Feature + + appenderMap sync.Map + scannerMap sync.Map +} + +func New() *Dialect { + d := new(Dialect) + d.tables = schema.NewTables(d) + d.features = feature.Returning | + feature.DefaultPlaceholder | + feature.DoubleColonCast | + feature.InsertTableAlias | + feature.DeleteTableAlias | + feature.TableCascade | + feature.TableIdentity | + feature.TableTruncate + return d +} + +func (d *Dialect) Init(*sql.DB) {} + +func (d *Dialect) Name() dialect.Name { + return dialect.PG +} + +func (d *Dialect) Features() feature.Feature { + return d.features +} + +func (d *Dialect) Tables() *schema.Tables { + return d.tables +} + +func (d *Dialect) OnTable(table *schema.Table) { + for _, field := range table.FieldMap { + d.onField(field) + } +} + +func (d *Dialect) onField(field *schema.Field) { + field.DiscoveredSQLType = fieldSQLType(field) + + if field.AutoIncrement { + switch field.DiscoveredSQLType { + case sqltype.SmallInt: + field.CreateTableSQLType = pgTypeSmallSerial + case sqltype.Integer: + field.CreateTableSQLType = pgTypeSerial + case sqltype.BigInt: + field.CreateTableSQLType = pgTypeBigSerial + } + } + + if field.Tag.HasOption("array") { + field.Append = arrayAppender(field.IndirectType) + field.Scan = arrayScanner(field.IndirectType) + } +} + +func (d *Dialect) IdentQuote() byte { + return '"' +} + +func (d *Dialect) Append(fmter schema.Formatter, b []byte, v interface{}) []byte { + switch v := v.(type) { + case nil: + return dialect.AppendNull(b) + case bool: + return dialect.AppendBool(b, v) + case int: + return strconv.AppendInt(b, int64(v), 10) + case int32: + return strconv.AppendInt(b, int64(v), 10) + case int64: + return strconv.AppendInt(b, v, 10) + case uint: + return strconv.AppendInt(b, int64(v), 10) + case uint32: + return strconv.AppendInt(b, int64(v), 10) + case uint64: + return strconv.AppendInt(b, int64(v), 10) + case float32: + return dialect.AppendFloat32(b, v) + case float64: + return dialect.AppendFloat64(b, v) + case string: + return dialect.AppendString(b, v) + case time.Time: + return dialect.AppendTime(b, v) + case []byte: + return dialect.AppendBytes(b, v) + case schema.QueryAppender: + return schema.AppendQueryAppender(fmter, b, v) + default: + vv := reflect.ValueOf(v) + if vv.Kind() == reflect.Ptr && vv.IsNil() { + return dialect.AppendNull(b) + } + appender := d.Appender(vv.Type()) + return appender(fmter, b, vv) + } +} + +func (d *Dialect) Appender(typ reflect.Type) schema.AppenderFunc { + if v, ok := d.appenderMap.Load(typ); ok { + return v.(schema.AppenderFunc) + } + + fn := schema.Appender(typ, customAppender) + + if v, ok := d.appenderMap.LoadOrStore(typ, fn); ok { + return v.(schema.AppenderFunc) + } + return fn +} + +func (d *Dialect) FieldAppender(field *schema.Field) schema.AppenderFunc { + return schema.FieldAppender(d, field) +} + +func (d *Dialect) Scanner(typ reflect.Type) schema.ScannerFunc { + if v, ok := d.scannerMap.Load(typ); ok { + return v.(schema.ScannerFunc) + } + + fn := scanner(typ) + + if v, ok := d.scannerMap.LoadOrStore(typ, fn); ok { + return v.(schema.ScannerFunc) + } + return fn +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/go.mod b/vendor/github.com/uptrace/bun/dialect/pgdialect/go.mod new file mode 100644 index 000000000..0cad1ce5b --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/go.mod @@ -0,0 +1,7 @@ +module github.com/uptrace/bun/dialect/pgdialect + +go 1.16 + +replace github.com/uptrace/bun => ../.. + +require github.com/uptrace/bun v0.4.3 diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/go.sum b/vendor/github.com/uptrace/bun/dialect/pgdialect/go.sum new file mode 100644 index 000000000..4d0f1c1bb --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/go.sum @@ -0,0 +1,22 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= +github.com/vmihailenco/msgpack/v5 v5.3.4 h1:qMKAwOV+meBw2Y8k9cVwAy7qErtYCwBzZ2ellBfvnqc= +github.com/vmihailenco/msgpack/v5 v5.3.4/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 h1:RqytpXGR1iVNX7psjB3ff8y7sNFinVFvkx1c8SjBkio= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/safe.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/safe.go new file mode 100644 index 000000000..dff30b9c5 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/safe.go @@ -0,0 +1,11 @@ +// +build appengine + +package pgdialect + +func bytesToString(b []byte) string { + return string(b) +} + +func stringToBytes(s string) []byte { + return []byte(s) +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/scan.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/scan.go new file mode 100644 index 000000000..9e22282f5 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/scan.go @@ -0,0 +1,28 @@ +package pgdialect + +import ( + "fmt" + "reflect" + + "github.com/uptrace/bun/schema" +) + +func scanner(typ reflect.Type) schema.ScannerFunc { + if typ.Kind() == reflect.Interface { + return scanInterface + } + return schema.Scanner(typ) +} + +func scanInterface(dest reflect.Value, src interface{}) error { + if dest.IsNil() { + dest.Set(reflect.ValueOf(src)) + return nil + } + + dest = dest.Elem() + if fn := scanner(dest.Type()); fn != nil { + return fn(dest, src) + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go new file mode 100644 index 000000000..4c2d8075d --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/sqltype.go @@ -0,0 +1,104 @@ +package pgdialect + +import ( + "encoding/json" + "net" + "reflect" + "time" + + "github.com/uptrace/bun/dialect/sqltype" + "github.com/uptrace/bun/schema" +) + +const ( + // Date / Time + pgTypeTimestampTz = "TIMESTAMPTZ" // Timestamp with a time zone + pgTypeDate = "DATE" // Date + pgTypeTime = "TIME" // Time without a time zone + pgTypeTimeTz = "TIME WITH TIME ZONE" // Time with a time zone + pgTypeInterval = "INTERVAL" // Time Interval + + // Network Addresses + pgTypeInet = "INET" // IPv4 or IPv6 hosts and networks + pgTypeCidr = "CIDR" // IPv4 or IPv6 networks + pgTypeMacaddr = "MACADDR" // MAC addresses + + // Serial Types + pgTypeSmallSerial = "SMALLSERIAL" // 2 byte autoincrementing integer + pgTypeSerial = "SERIAL" // 4 byte autoincrementing integer + pgTypeBigSerial = "BIGSERIAL" // 8 byte autoincrementing integer + + // Character Types + pgTypeChar = "CHAR" // fixed length string (blank padded) + pgTypeText = "TEXT" // variable length string without limit + + // JSON Types + pgTypeJSON = "JSON" // text representation of json data + pgTypeJSONB = "JSONB" // binary representation of json data + + // Binary Data Types + pgTypeBytea = "BYTEA" // binary string +) + +var ( + timeType = reflect.TypeOf((*time.Time)(nil)).Elem() + ipType = reflect.TypeOf((*net.IP)(nil)).Elem() + ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() + jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() +) + +func fieldSQLType(field *schema.Field) string { + if field.UserSQLType != "" { + return field.UserSQLType + } + + if v, ok := field.Tag.Options["composite"]; ok { + return v + } + + if _, ok := field.Tag.Options["hstore"]; ok { + return "hstore" + } + + if _, ok := field.Tag.Options["array"]; ok { + switch field.IndirectType.Kind() { + case reflect.Slice, reflect.Array: + sqlType := sqlType(field.IndirectType.Elem()) + return sqlType + "[]" + } + } + + return sqlType(field.IndirectType) +} + +func sqlType(typ reflect.Type) string { + switch typ { + case ipType: + return pgTypeInet + case ipNetType: + return pgTypeCidr + case jsonRawMessageType: + return pgTypeJSONB + } + + sqlType := schema.DiscoverSQLType(typ) + switch sqlType { + case sqltype.Timestamp: + sqlType = pgTypeTimestampTz + } + + switch typ.Kind() { + case reflect.Map, reflect.Struct: + if sqlType == sqltype.VarChar { + return pgTypeJSONB + } + return sqlType + case reflect.Array, reflect.Slice: + if typ.Elem().Kind() == reflect.Uint8 { + return pgTypeBytea + } + return pgTypeJSONB + } + + return sqlType +} diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/unsafe.go b/vendor/github.com/uptrace/bun/dialect/pgdialect/unsafe.go new file mode 100644 index 000000000..2a02a20b1 --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/unsafe.go @@ -0,0 +1,18 @@ +// +build !appengine + +package pgdialect + +import "unsafe" + +func bytesToString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +func stringToBytes(s string) []byte { + return *(*[]byte)(unsafe.Pointer( + &struct { + string + Cap int + }{s, len(s)}, + )) +} diff --git a/vendor/github.com/uptrace/bun/dialect/sqltype/sqltype.go b/vendor/github.com/uptrace/bun/dialect/sqltype/sqltype.go new file mode 100644 index 000000000..84a51d26d --- /dev/null +++ b/vendor/github.com/uptrace/bun/dialect/sqltype/sqltype.go @@ -0,0 +1,14 @@ +package sqltype + +const ( + Boolean = "BOOLEAN" + SmallInt = "SMALLINT" + Integer = "INTEGER" + BigInt = "BIGINT" + Real = "REAL" + DoublePrecision = "DOUBLE PRECISION" + VarChar = "VARCHAR" + Timestamp = "TIMESTAMP" + JSON = "JSON" + JSONB = "JSONB" +) diff --git a/vendor/github.com/go-pg/pg/v10/pgjson/json.go b/vendor/github.com/uptrace/bun/extra/bunjson/json.go index c401dc946..eff9d3f0e 100644 --- a/vendor/github.com/go-pg/pg/v10/pgjson/json.go +++ b/vendor/github.com/uptrace/bun/extra/bunjson/json.go @@ -1,4 +1,4 @@ -package pgjson +package bunjson import ( "encoding/json" diff --git a/vendor/github.com/go-pg/pg/v10/pgjson/provider.go b/vendor/github.com/uptrace/bun/extra/bunjson/provider.go index a4b663ce4..7f810e122 100644 --- a/vendor/github.com/go-pg/pg/v10/pgjson/provider.go +++ b/vendor/github.com/uptrace/bun/extra/bunjson/provider.go @@ -1,4 +1,4 @@ -package pgjson +package bunjson import ( "io" diff --git a/vendor/github.com/uptrace/bun/go.mod b/vendor/github.com/uptrace/bun/go.mod new file mode 100644 index 000000000..92def2a3d --- /dev/null +++ b/vendor/github.com/uptrace/bun/go.mod @@ -0,0 +1,12 @@ +module github.com/uptrace/bun + +go 1.16 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/jinzhu/inflection v1.0.0 + github.com/stretchr/testify v1.7.0 + github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc + github.com/vmihailenco/msgpack/v5 v5.3.4 + golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 // indirect +) diff --git a/vendor/github.com/uptrace/bun/go.sum b/vendor/github.com/uptrace/bun/go.sum new file mode 100644 index 000000000..3bf0a4a3f --- /dev/null +++ b/vendor/github.com/uptrace/bun/go.sum @@ -0,0 +1,23 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= +github.com/vmihailenco/msgpack/v5 v5.3.4 h1:qMKAwOV+meBw2Y8k9cVwAy7qErtYCwBzZ2ellBfvnqc= +github.com/vmihailenco/msgpack/v5 v5.3.4/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 h1:RqytpXGR1iVNX7psjB3ff8y7sNFinVFvkx1c8SjBkio= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/vendor/github.com/uptrace/bun/hook.go b/vendor/github.com/uptrace/bun/hook.go new file mode 100644 index 000000000..4cfa68fa6 --- /dev/null +++ b/vendor/github.com/uptrace/bun/hook.go @@ -0,0 +1,98 @@ +package bun + +import ( + "context" + "database/sql" + "reflect" + "sync/atomic" + "time" + + "github.com/uptrace/bun/schema" +) + +type QueryEvent struct { + DB *DB + + QueryAppender schema.QueryAppender + Query string + QueryArgs []interface{} + + StartTime time.Time + Result sql.Result + Err error + + Stash map[interface{}]interface{} +} + +type QueryHook interface { + BeforeQuery(context.Context, *QueryEvent) context.Context + AfterQuery(context.Context, *QueryEvent) +} + +func (db *DB) beforeQuery( + ctx context.Context, + queryApp schema.QueryAppender, + query string, + queryArgs []interface{}, +) (context.Context, *QueryEvent) { + atomic.AddUint64(&db.stats.Queries, 1) + + if len(db.queryHooks) == 0 { + return ctx, nil + } + + event := &QueryEvent{ + DB: db, + + QueryAppender: queryApp, + Query: query, + QueryArgs: queryArgs, + + StartTime: time.Now(), + } + + for _, hook := range db.queryHooks { + ctx = hook.BeforeQuery(ctx, event) + } + + return ctx, event +} + +func (db *DB) afterQuery( + ctx context.Context, + event *QueryEvent, + res sql.Result, + err error, +) { + switch err { + case nil, sql.ErrNoRows: + // nothing + default: + atomic.AddUint64(&db.stats.Errors, 1) + } + + if event == nil { + return + } + + event.Result = res + event.Err = err + + db.afterQueryFromIndex(ctx, event, len(db.queryHooks)-1) +} + +func (db *DB) afterQueryFromIndex(ctx context.Context, event *QueryEvent, hookIndex int) { + for ; hookIndex >= 0; hookIndex-- { + db.queryHooks[hookIndex].AfterQuery(ctx, event) + } +} + +//------------------------------------------------------------------------------ + +func callBeforeScanHook(ctx context.Context, v reflect.Value) error { + return v.Interface().(schema.BeforeScanHook).BeforeScan(ctx) +} + +func callAfterScanHook(ctx context.Context, v reflect.Value) error { + return v.Interface().(schema.AfterScanHook).AfterScan(ctx) +} diff --git a/vendor/github.com/uptrace/bun/internal/flag.go b/vendor/github.com/uptrace/bun/internal/flag.go new file mode 100644 index 000000000..b42f59df7 --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/flag.go @@ -0,0 +1,16 @@ +package internal + +type Flag uint64 + +func (flag Flag) Has(other Flag) bool { + return flag&other == other +} + +func (flag Flag) Set(other Flag) Flag { + return flag | other +} + +func (flag Flag) Remove(other Flag) Flag { + flag &= ^other + return flag +} diff --git a/vendor/github.com/uptrace/bun/internal/hex.go b/vendor/github.com/uptrace/bun/internal/hex.go new file mode 100644 index 000000000..6fae2bb78 --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/hex.go @@ -0,0 +1,43 @@ +package internal + +import ( + fasthex "github.com/tmthrgd/go-hex" +) + +type HexEncoder struct { + b []byte + written bool +} + +func NewHexEncoder(b []byte) *HexEncoder { + return &HexEncoder{ + b: b, + } +} + +func (enc *HexEncoder) Bytes() []byte { + return enc.b +} + +func (enc *HexEncoder) Write(b []byte) (int, error) { + if !enc.written { + enc.b = append(enc.b, '\'') + enc.b = append(enc.b, `\x`...) + enc.written = true + } + + i := len(enc.b) + enc.b = append(enc.b, make([]byte, fasthex.EncodedLen(len(b)))...) + fasthex.Encode(enc.b[i:], b) + + return len(b), nil +} + +func (enc *HexEncoder) Close() error { + if enc.written { + enc.b = append(enc.b, '\'') + } else { + enc.b = append(enc.b, "NULL"...) + } + return nil +} diff --git a/vendor/github.com/uptrace/bun/internal/logger.go b/vendor/github.com/uptrace/bun/internal/logger.go new file mode 100644 index 000000000..2e22a0893 --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/logger.go @@ -0,0 +1,27 @@ +package internal + +import ( + "fmt" + "log" + "os" +) + +var Warn = log.New(os.Stderr, "WARN: bun: ", log.LstdFlags) + +var Deprecated = log.New(os.Stderr, "DEPRECATED: bun: ", log.LstdFlags) + +type Logging interface { + Printf(format string, v ...interface{}) +} + +type logger struct { + log *log.Logger +} + +func (l *logger) Printf(format string, v ...interface{}) { + _ = l.log.Output(2, fmt.Sprintf(format, v...)) +} + +var Logger Logging = &logger{ + log: log.New(os.Stderr, "bun: ", log.LstdFlags|log.Lshortfile), +} diff --git a/vendor/github.com/uptrace/bun/internal/map_key.go b/vendor/github.com/uptrace/bun/internal/map_key.go new file mode 100644 index 000000000..bb5fcca8c --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/map_key.go @@ -0,0 +1,67 @@ +package internal + +import "reflect" + +var ifaceType = reflect.TypeOf((*interface{})(nil)).Elem() + +type MapKey struct { + iface interface{} +} + +func NewMapKey(is []interface{}) MapKey { + return MapKey{ + iface: newMapKey(is), + } +} + +func newMapKey(is []interface{}) interface{} { + switch len(is) { + case 1: + ptr := new([1]interface{}) + copy((*ptr)[:], is) + return *ptr + case 2: + ptr := new([2]interface{}) + copy((*ptr)[:], is) + return *ptr + case 3: + ptr := new([3]interface{}) + copy((*ptr)[:], is) + return *ptr + case 4: + ptr := new([4]interface{}) + copy((*ptr)[:], is) + return *ptr + case 5: + ptr := new([5]interface{}) + copy((*ptr)[:], is) + return *ptr + case 6: + ptr := new([6]interface{}) + copy((*ptr)[:], is) + return *ptr + case 7: + ptr := new([7]interface{}) + copy((*ptr)[:], is) + return *ptr + case 8: + ptr := new([8]interface{}) + copy((*ptr)[:], is) + return *ptr + case 9: + ptr := new([9]interface{}) + copy((*ptr)[:], is) + return *ptr + case 10: + ptr := new([10]interface{}) + copy((*ptr)[:], is) + return *ptr + default: + } + + at := reflect.New(reflect.ArrayOf(len(is), ifaceType)).Elem() + for i, v := range is { + *(at.Index(i).Addr().Interface().(*interface{})) = v + } + return at.Interface() +} diff --git a/vendor/github.com/go-pg/pg/v10/internal/parser/parser.go b/vendor/github.com/uptrace/bun/internal/parser/parser.go index f2db676c9..cdfc0be16 100644 --- a/vendor/github.com/go-pg/pg/v10/internal/parser/parser.go +++ b/vendor/github.com/uptrace/bun/internal/parser/parser.go @@ -4,7 +4,7 @@ import ( "bytes" "strconv" - "github.com/go-pg/pg/v10/internal" + "github.com/uptrace/bun/internal" ) type Parser struct { @@ -19,7 +19,7 @@ func New(b []byte) *Parser { } func NewString(s string) *Parser { - return New(internal.StringToBytes(s)) + return New(internal.Bytes(s)) } func (p *Parser) Valid() bool { @@ -88,7 +88,7 @@ func (p *Parser) ReadIdentifier() (string, bool) { if ind := bytes.IndexByte(p.b[s:], ')'); ind != -1 { b := p.b[s : s+ind] p.i = s + ind + 1 - return internal.BytesToString(b), false + return internal.String(b), false } } @@ -110,7 +110,7 @@ func (p *Parser) ReadIdentifier() (string, bool) { } b := p.b[p.i : p.i+ind] p.i += ind - return internal.BytesToString(b), !alpha + return internal.String(b), !alpha } func (p *Parser) ReadNumber() int { diff --git a/vendor/github.com/go-pg/pg/v10/internal/safe.go b/vendor/github.com/uptrace/bun/internal/safe.go index 870fe541f..862ff0eb3 100644 --- a/vendor/github.com/go-pg/pg/v10/internal/safe.go +++ b/vendor/github.com/uptrace/bun/internal/safe.go @@ -2,10 +2,10 @@ package internal -func BytesToString(b []byte) string { +func String(b []byte) string { return string(b) } -func StringToBytes(s string) []byte { +func Bytes(s string) []byte { return []byte(s) } diff --git a/vendor/github.com/uptrace/bun/internal/tagparser/parser.go b/vendor/github.com/uptrace/bun/internal/tagparser/parser.go new file mode 100644 index 000000000..8ef89248c --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/tagparser/parser.go @@ -0,0 +1,147 @@ +package tagparser + +import ( + "strings" +) + +type Tag struct { + Name string + Options map[string]string +} + +func (t Tag) HasOption(name string) bool { + _, ok := t.Options[name] + return ok +} + +func Parse(s string) Tag { + p := parser{ + s: s, + } + p.parse() + return p.tag +} + +type parser struct { + s string + i int + + tag Tag + seenName bool // for empty names +} + +func (p *parser) setName(name string) { + if p.seenName { + p.addOption(name, "") + } else { + p.seenName = true + p.tag.Name = name + } +} + +func (p *parser) addOption(key, value string) { + p.seenName = true + if key == "" { + return + } + if p.tag.Options == nil { + p.tag.Options = make(map[string]string) + } + p.tag.Options[key] = value +} + +func (p *parser) parse() { + for p.valid() { + p.parseKeyValue() + if p.peek() == ',' { + p.i++ + } + } +} + +func (p *parser) parseKeyValue() { + start := p.i + + for p.valid() { + switch c := p.read(); c { + case ',': + key := p.s[start : p.i-1] + p.setName(key) + return + case ':': + key := p.s[start : p.i-1] + value := p.parseValue() + p.addOption(key, value) + return + case '"': + key := p.parseQuotedValue() + p.setName(key) + return + } + } + + key := p.s[start:p.i] + p.setName(key) +} + +func (p *parser) parseValue() string { + start := p.i + + for p.valid() { + switch c := p.read(); c { + case '"': + return p.parseQuotedValue() + case ',': + return p.s[start : p.i-1] + } + } + + if p.i == start { + return "" + } + return p.s[start:p.i] +} + +func (p *parser) parseQuotedValue() string { + if i := strings.IndexByte(p.s[p.i:], '"'); i >= 0 && p.s[p.i+i-1] != '\\' { + s := p.s[p.i : p.i+i] + p.i += i + 1 + return s + } + + b := make([]byte, 0, 16) + + for p.valid() { + switch c := p.read(); c { + case '\\': + b = append(b, p.read()) + case '"': + return string(b) + default: + b = append(b, c) + } + } + + return "" +} + +func (p *parser) valid() bool { + return p.i < len(p.s) +} + +func (p *parser) read() byte { + if !p.valid() { + return 0 + } + c := p.s[p.i] + p.i++ + return c +} + +func (p *parser) peek() byte { + if !p.valid() { + return 0 + } + c := p.s[p.i] + return c +} diff --git a/vendor/github.com/go-pg/pg/v10/types/time.go b/vendor/github.com/uptrace/bun/internal/time.go index e68a7a19a..e4e0804b0 100644 --- a/vendor/github.com/go-pg/pg/v10/types/time.go +++ b/vendor/github.com/uptrace/bun/internal/time.go @@ -1,9 +1,8 @@ -package types +package internal import ( + "fmt" "time" - - "github.com/go-pg/pg/v10/internal" ) const ( @@ -15,13 +14,10 @@ const ( timestamptzFormat3 = "2006-01-02 15:04:05.999999999-07" ) -func ParseTime(b []byte) (time.Time, error) { - s := internal.BytesToString(b) - return ParseTimeString(s) -} - -func ParseTimeString(s string) (time.Time, error) { +func ParseTime(s string) (time.Time, error) { switch l := len(s); { + case l < len("15:04:05"): + return time.Time{}, fmt.Errorf("bun: can't parse time=%q", s) case l <= len(timeFormat): if s[2] == ':' { return time.ParseInLocation(timeFormat, s, time.UTC) @@ -43,14 +39,3 @@ func ParseTimeString(s string) (time.Time, error) { return time.ParseInLocation(timestampFormat, s, time.UTC) } } - -func AppendTime(b []byte, tm time.Time, flags int) []byte { - if flags == 1 { - b = append(b, '\'') - } - b = tm.UTC().AppendFormat(b, timestamptzFormat) - if flags == 1 { - b = append(b, '\'') - } - return b -} diff --git a/vendor/github.com/go-pg/pg/v10/internal/underscore.go b/vendor/github.com/uptrace/bun/internal/underscore.go index e71c11705..9de52fb7b 100644 --- a/vendor/github.com/go-pg/pg/v10/internal/underscore.go +++ b/vendor/github.com/uptrace/bun/internal/underscore.go @@ -65,29 +65,3 @@ func ToExported(s string) string { } return s } - -func UpperString(s string) string { - if isUpperString(s) { - return s - } - - b := make([]byte, len(s)) - for i := range b { - c := s[i] - if IsLower(c) { - c = ToUpper(c) - } - b[i] = c - } - return string(b) -} - -func isUpperString(s string) bool { - for i := 0; i < len(s); i++ { - c := s[i] - if IsLower(c) { - return false - } - } - return true -} diff --git a/vendor/github.com/uptrace/bun/internal/unsafe.go b/vendor/github.com/uptrace/bun/internal/unsafe.go new file mode 100644 index 000000000..4bc79701f --- /dev/null +++ b/vendor/github.com/uptrace/bun/internal/unsafe.go @@ -0,0 +1,20 @@ +// +build !appengine + +package internal + +import "unsafe" + +// String converts byte slice to string. +func String(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +// Bytes converts string to byte slice. +func Bytes(s string) []byte { + return *(*[]byte)(unsafe.Pointer( + &struct { + string + Cap int + }{s, len(s)}, + )) +} diff --git a/vendor/github.com/go-pg/pg/v10/internal/util.go b/vendor/github.com/uptrace/bun/internal/util.go index 80ad1dd9a..c831dc659 100644 --- a/vendor/github.com/go-pg/pg/v10/internal/util.go +++ b/vendor/github.com/uptrace/bun/internal/util.go @@ -1,23 +1,9 @@ package internal import ( - "context" "reflect" - "time" ) -func Sleep(ctx context.Context, dur time.Duration) error { - t := time.NewTimer(dur) - defer t.Stop() - - select { - case <-t.C: - return nil - case <-ctx.Done(): - return ctx.Err() - } -} - func MakeSliceNextElemFunc(v reflect.Value) func() reflect.Value { if v.Kind() == reflect.Array { var pos int diff --git a/vendor/github.com/uptrace/bun/join.go b/vendor/github.com/uptrace/bun/join.go new file mode 100644 index 000000000..4557f5bc0 --- /dev/null +++ b/vendor/github.com/uptrace/bun/join.go @@ -0,0 +1,308 @@ +package bun + +import ( + "context" + "reflect" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type join struct { + Parent *join + BaseModel tableModel + JoinModel tableModel + Relation *schema.Relation + + ApplyQueryFunc func(*SelectQuery) *SelectQuery + columns []schema.QueryWithArgs +} + +func (j *join) applyQuery(q *SelectQuery) { + if j.ApplyQueryFunc == nil { + return + } + + var table *schema.Table + var columns []schema.QueryWithArgs + + // Save state. + table, q.table = q.table, j.JoinModel.Table() + columns, q.columns = q.columns, nil + + q = j.ApplyQueryFunc(q) + + // Restore state. + q.table = table + j.columns, q.columns = q.columns, columns +} + +func (j *join) Select(ctx context.Context, q *SelectQuery) error { + switch j.Relation.Type { + case schema.HasManyRelation: + return j.selectMany(ctx, q) + case schema.ManyToManyRelation: + return j.selectM2M(ctx, q) + } + panic("not reached") +} + +func (j *join) selectMany(ctx context.Context, q *SelectQuery) error { + q = j.manyQuery(q) + if q == nil { + return nil + } + return q.Scan(ctx) +} + +func (j *join) manyQuery(q *SelectQuery) *SelectQuery { + hasManyModel := newHasManyModel(j) + if hasManyModel == nil { + return nil + } + + q = q.Model(hasManyModel) + + var where []byte + if len(j.Relation.JoinFields) > 1 { + where = append(where, '(') + } + where = appendColumns(where, j.JoinModel.Table().SQLAlias, j.Relation.JoinFields) + if len(j.Relation.JoinFields) > 1 { + where = append(where, ')') + } + where = append(where, " IN ("...) + where = appendChildValues( + q.db.Formatter(), + where, + j.JoinModel.Root(), + j.JoinModel.ParentIndex(), + j.Relation.BaseFields, + ) + where = append(where, ")"...) + q = q.Where(internal.String(where)) + + if j.Relation.PolymorphicField != nil { + q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue) + } + + j.applyQuery(q) + q = q.Apply(j.hasManyColumns) + + return q +} + +func (j *join) hasManyColumns(q *SelectQuery) *SelectQuery { + if j.Relation.M2MTable != nil { + q = q.ColumnExpr(string(j.Relation.M2MTable.SQLAlias) + ".*") + } + + b := make([]byte, 0, 32) + + if len(j.columns) > 0 { + for i, col := range j.columns { + if i > 0 { + b = append(b, ", "...) + } + + var err error + b, err = col.AppendQuery(q.db.fmter, b) + if err != nil { + q.err = err + return q + } + } + } else { + joinTable := j.JoinModel.Table() + b = appendColumns(b, joinTable.SQLAlias, joinTable.Fields) + } + + q = q.ColumnExpr(internal.String(b)) + + return q +} + +func (j *join) selectM2M(ctx context.Context, q *SelectQuery) error { + q = j.m2mQuery(q) + if q == nil { + return nil + } + return q.Scan(ctx) +} + +func (j *join) m2mQuery(q *SelectQuery) *SelectQuery { + fmter := q.db.fmter + + m2mModel := newM2MModel(j) + if m2mModel == nil { + return nil + } + q = q.Model(m2mModel) + + index := j.JoinModel.ParentIndex() + baseTable := j.BaseModel.Table() + + //nolint + var join []byte + join = append(join, "JOIN "...) + join = fmter.AppendQuery(join, string(j.Relation.M2MTable.Name)) + join = append(join, " AS "...) + join = append(join, j.Relation.M2MTable.SQLAlias...) + join = append(join, " ON ("...) + for i, col := range j.Relation.M2MBaseFields { + if i > 0 { + join = append(join, ", "...) + } + join = append(join, j.Relation.M2MTable.SQLAlias...) + join = append(join, '.') + join = append(join, col.SQLName...) + } + join = append(join, ") IN ("...) + join = appendChildValues(fmter, join, j.BaseModel.Root(), index, baseTable.PKs) + join = append(join, ")"...) + q = q.Join(internal.String(join)) + + joinTable := j.JoinModel.Table() + for i, m2mJoinField := range j.Relation.M2MJoinFields { + joinField := j.Relation.JoinFields[i] + q = q.Where("?.? = ?.?", + joinTable.SQLAlias, joinField.SQLName, + j.Relation.M2MTable.SQLAlias, m2mJoinField.SQLName) + } + + j.applyQuery(q) + q = q.Apply(j.hasManyColumns) + + return q +} + +func (j *join) hasParent() bool { + if j.Parent != nil { + switch j.Parent.Relation.Type { + case schema.HasOneRelation, schema.BelongsToRelation: + return true + } + } + return false +} + +func (j *join) appendAlias(fmter schema.Formatter, b []byte) []byte { + quote := fmter.IdentQuote() + + b = append(b, quote) + b = appendAlias(b, j) + b = append(b, quote) + return b +} + +func (j *join) appendAliasColumn(fmter schema.Formatter, b []byte, column string) []byte { + quote := fmter.IdentQuote() + + b = append(b, quote) + b = appendAlias(b, j) + b = append(b, "__"...) + b = append(b, column...) + b = append(b, quote) + return b +} + +func (j *join) appendBaseAlias(fmter schema.Formatter, b []byte) []byte { + quote := fmter.IdentQuote() + + if j.hasParent() { + b = append(b, quote) + b = appendAlias(b, j.Parent) + b = append(b, quote) + return b + } + return append(b, j.BaseModel.Table().SQLAlias...) +} + +func (j *join) appendSoftDelete(b []byte, flags internal.Flag) []byte { + b = append(b, '.') + b = append(b, j.JoinModel.Table().SoftDeleteField.SQLName...) + if flags.Has(deletedFlag) { + b = append(b, " IS NOT NULL"...) + } else { + b = append(b, " IS NULL"...) + } + return b +} + +func appendAlias(b []byte, j *join) []byte { + if j.hasParent() { + b = appendAlias(b, j.Parent) + b = append(b, "__"...) + } + b = append(b, j.Relation.Field.Name...) + return b +} + +func (j *join) appendHasOneJoin( + fmter schema.Formatter, b []byte, q *SelectQuery, +) (_ []byte, err error) { + isSoftDelete := j.JoinModel.Table().SoftDeleteField != nil && !q.flags.Has(allWithDeletedFlag) + + b = append(b, "LEFT JOIN "...) + b = fmter.AppendQuery(b, string(j.JoinModel.Table().SQLNameForSelects)) + b = append(b, " AS "...) + b = j.appendAlias(fmter, b) + + b = append(b, " ON "...) + + b = append(b, '(') + for i, baseField := range j.Relation.BaseFields { + if i > 0 { + b = append(b, " AND "...) + } + b = j.appendAlias(fmter, b) + b = append(b, '.') + b = append(b, j.Relation.JoinFields[i].SQLName...) + b = append(b, " = "...) + b = j.appendBaseAlias(fmter, b) + b = append(b, '.') + b = append(b, baseField.SQLName...) + } + b = append(b, ')') + + if isSoftDelete { + b = append(b, " AND "...) + b = j.appendAlias(fmter, b) + b = j.appendSoftDelete(b, q.flags) + } + + return b, nil +} + +func appendChildValues( + fmter schema.Formatter, b []byte, v reflect.Value, index []int, fields []*schema.Field, +) []byte { + seen := make(map[string]struct{}) + walk(v, index, func(v reflect.Value) { + start := len(b) + + if len(fields) > 1 { + b = append(b, '(') + } + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + b = f.AppendValue(fmter, b, v) + } + if len(fields) > 1 { + b = append(b, ')') + } + b = append(b, ", "...) + + if _, ok := seen[string(b[start:])]; ok { + b = b[:start] + } else { + seen[string(b[start:])] = struct{}{} + } + }) + if len(seen) > 0 { + b = b[:len(b)-2] // trim ", " + } + return b +} diff --git a/vendor/github.com/uptrace/bun/model.go b/vendor/github.com/uptrace/bun/model.go new file mode 100644 index 000000000..c9f0f3583 --- /dev/null +++ b/vendor/github.com/uptrace/bun/model.go @@ -0,0 +1,207 @@ +package bun + +import ( + "context" + "database/sql" + "errors" + "fmt" + "reflect" + "time" + + "github.com/uptrace/bun/schema" +) + +var errNilModel = errors.New("bun: Model(nil)") + +var timeType = reflect.TypeOf((*time.Time)(nil)).Elem() + +type Model interface { + ScanRows(ctx context.Context, rows *sql.Rows) (int, error) + Value() interface{} +} + +type rowScanner interface { + ScanRow(ctx context.Context, rows *sql.Rows) error +} + +type model interface { + Model +} + +type tableModel interface { + model + + schema.BeforeScanHook + schema.AfterScanHook + ScanColumn(column string, src interface{}) error + + Table() *schema.Table + Relation() *schema.Relation + + Join(string, func(*SelectQuery) *SelectQuery) *join + GetJoin(string) *join + GetJoins() []join + AddJoin(join) *join + + Root() reflect.Value + ParentIndex() []int + Mount(reflect.Value) + + updateSoftDeleteField() error +} + +func newModel(db *DB, dest []interface{}) (model, error) { + if len(dest) == 1 { + return _newModel(db, dest[0], true) + } + + values := make([]reflect.Value, len(dest)) + + for i, el := range dest { + v := reflect.ValueOf(el) + if v.Kind() != reflect.Ptr { + return nil, fmt.Errorf("bun: Scan(non-pointer %T)", dest) + } + + v = v.Elem() + if v.Kind() != reflect.Slice { + return newScanModel(db, dest), nil + } + + values[i] = v + } + + return newSliceModel(db, dest, values), nil +} + +func newSingleModel(db *DB, dest interface{}) (model, error) { + return _newModel(db, dest, false) +} + +func _newModel(db *DB, dest interface{}, scan bool) (model, error) { + switch dest := dest.(type) { + case nil: + return nil, errNilModel + case Model: + return dest, nil + case sql.Scanner: + if !scan { + return nil, fmt.Errorf("bun: Model(unsupported %T)", dest) + } + return newScanModel(db, []interface{}{dest}), nil + } + + v := reflect.ValueOf(dest) + if !v.IsValid() { + return nil, errNilModel + } + if v.Kind() != reflect.Ptr { + return nil, fmt.Errorf("bun: Model(non-pointer %T)", dest) + } + + if v.IsNil() { + typ := v.Type().Elem() + if typ.Kind() == reflect.Struct { + return newStructTableModel(db, dest, db.Table(typ)), nil + } + return nil, fmt.Errorf("bun: Model(nil %T)", dest) + } + + v = v.Elem() + + switch v.Kind() { + case reflect.Map: + typ := v.Type() + if err := validMap(typ); err != nil { + return nil, err + } + mapPtr := v.Addr().Interface().(*map[string]interface{}) + return newMapModel(db, mapPtr), nil + case reflect.Struct: + if v.Type() != timeType { + return newStructTableModelValue(db, dest, v), nil + } + case reflect.Slice: + switch elemType := sliceElemType(v); elemType.Kind() { + case reflect.Struct: + if elemType != timeType { + return newSliceTableModel(db, dest, v, elemType), nil + } + case reflect.Map: + if err := validMap(elemType); err != nil { + return nil, err + } + slicePtr := v.Addr().Interface().(*[]map[string]interface{}) + return newMapSliceModel(db, slicePtr), nil + } + return newSliceModel(db, []interface{}{dest}, []reflect.Value{v}), nil + } + + if scan { + return newScanModel(db, []interface{}{dest}), nil + } + + return nil, fmt.Errorf("bun: Model(unsupported %T)", dest) +} + +func newTableModelIndex( + db *DB, + table *schema.Table, + root reflect.Value, + index []int, + rel *schema.Relation, +) (tableModel, error) { + typ := typeByIndex(table.Type, index) + + if typ.Kind() == reflect.Struct { + return &structTableModel{ + db: db, + table: table.Dialect().Tables().Get(typ), + rel: rel, + + root: root, + index: index, + }, nil + } + + if typ.Kind() == reflect.Slice { + structType := indirectType(typ.Elem()) + if structType.Kind() == reflect.Struct { + m := sliceTableModel{ + structTableModel: structTableModel{ + db: db, + table: table.Dialect().Tables().Get(structType), + rel: rel, + + root: root, + index: index, + }, + } + m.init(typ) + return &m, nil + } + } + + return nil, fmt.Errorf("bun: NewModel(%s)", typ) +} + +func validMap(typ reflect.Type) error { + if typ.Key().Kind() != reflect.String || typ.Elem().Kind() != reflect.Interface { + return fmt.Errorf("bun: Model(unsupported %s) (expected *map[string]interface{})", + typ) + } + return nil +} + +//------------------------------------------------------------------------------ + +func isSingleRowModel(m model) bool { + switch m.(type) { + case *mapModel, + *structTableModel, + *scanModel: + return true + default: + return false + } +} diff --git a/vendor/github.com/uptrace/bun/model_map.go b/vendor/github.com/uptrace/bun/model_map.go new file mode 100644 index 000000000..81c1a4a3b --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_map.go @@ -0,0 +1,183 @@ +package bun + +import ( + "context" + "database/sql" + "reflect" + "sort" + + "github.com/uptrace/bun/schema" +) + +type mapModel struct { + db *DB + + dest *map[string]interface{} + m map[string]interface{} + + rows *sql.Rows + columns []string + _columnTypes []*sql.ColumnType + scanIndex int +} + +var _ model = (*mapModel)(nil) + +func newMapModel(db *DB, dest *map[string]interface{}) *mapModel { + m := &mapModel{ + db: db, + dest: dest, + } + if dest != nil { + m.m = *dest + } + return m +} + +func (m *mapModel) Value() interface{} { + return m.dest +} + +func (m *mapModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + if !rows.Next() { + return 0, rows.Err() + } + + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + m.rows = rows + m.columns = columns + dest := makeDest(m, len(columns)) + + if m.m == nil { + m.m = make(map[string]interface{}, len(m.columns)) + } + + m.scanIndex = 0 + if err := rows.Scan(dest...); err != nil { + return 0, err + } + + *m.dest = m.m + + return 1, nil +} + +func (m *mapModel) Scan(src interface{}) error { + if _, ok := src.([]byte); !ok { + return m.scanRaw(src) + } + + columnTypes, err := m.columnTypes() + if err != nil { + return err + } + + scanType := columnTypes[m.scanIndex].ScanType() + switch scanType.Kind() { + case reflect.Interface: + return m.scanRaw(src) + case reflect.Slice: + if scanType.Elem().Kind() == reflect.Uint8 { + return m.scanRaw(src) + } + } + + dest := reflect.New(scanType).Elem() + if err := schema.Scanner(scanType)(dest, src); err != nil { + return err + } + + return m.scanRaw(dest.Interface()) +} + +func (m *mapModel) columnTypes() ([]*sql.ColumnType, error) { + if m._columnTypes == nil { + columnTypes, err := m.rows.ColumnTypes() + if err != nil { + return nil, err + } + m._columnTypes = columnTypes + } + return m._columnTypes, nil +} + +func (m *mapModel) scanRaw(src interface{}) error { + columnName := m.columns[m.scanIndex] + m.scanIndex++ + m.m[columnName] = src + return nil +} + +func (m *mapModel) appendColumnsValues(fmter schema.Formatter, b []byte) []byte { + keys := make([]string, 0, len(m.m)) + + for k := range m.m { + keys = append(keys, k) + } + sort.Strings(keys) + + b = append(b, " ("...) + + for i, k := range keys { + if i > 0 { + b = append(b, ", "...) + } + b = fmter.AppendIdent(b, k) + } + + b = append(b, ") VALUES ("...) + + isTemplate := fmter.IsNop() + for i, k := range keys { + if i > 0 { + b = append(b, ", "...) + } + if isTemplate { + b = append(b, '?') + } else { + b = fmter.Dialect().Append(fmter, b, m.m[k]) + } + } + + b = append(b, ")"...) + + return b +} + +func (m *mapModel) appendSet(fmter schema.Formatter, b []byte) []byte { + keys := make([]string, 0, len(m.m)) + + for k := range m.m { + keys = append(keys, k) + } + sort.Strings(keys) + + isTemplate := fmter.IsNop() + for i, k := range keys { + if i > 0 { + b = append(b, ", "...) + } + + b = fmter.AppendIdent(b, k) + b = append(b, " = "...) + if isTemplate { + b = append(b, '?') + } else { + b = fmter.Dialect().Append(fmter, b, m.m[k]) + } + } + + return b +} + +func makeDest(v interface{}, n int) []interface{} { + dest := make([]interface{}, n) + for i := range dest { + dest[i] = v + } + return dest +} diff --git a/vendor/github.com/uptrace/bun/model_map_slice.go b/vendor/github.com/uptrace/bun/model_map_slice.go new file mode 100644 index 000000000..5c6f48e44 --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_map_slice.go @@ -0,0 +1,162 @@ +package bun + +import ( + "context" + "database/sql" + "errors" + "sort" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/schema" +) + +type mapSliceModel struct { + mapModel + dest *[]map[string]interface{} + + keys []string +} + +var _ model = (*mapSliceModel)(nil) + +func newMapSliceModel(db *DB, dest *[]map[string]interface{}) *mapSliceModel { + return &mapSliceModel{ + mapModel: mapModel{ + db: db, + }, + dest: dest, + } +} + +func (m *mapSliceModel) Value() interface{} { + return m.dest +} + +func (m *mapSliceModel) SetCap(cap int) { + if cap > 100 { + cap = 100 + } + if slice := *m.dest; len(slice) < cap { + *m.dest = make([]map[string]interface{}, 0, cap) + } +} + +func (m *mapSliceModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + m.rows = rows + m.columns = columns + dest := makeDest(m, len(columns)) + + slice := *m.dest + if len(slice) > 0 { + slice = slice[:0] + } + + var n int + + for rows.Next() { + m.m = make(map[string]interface{}, len(m.columns)) + + m.scanIndex = 0 + if err := rows.Scan(dest...); err != nil { + return 0, err + } + + slice = append(slice, m.m) + n++ + } + if err := rows.Err(); err != nil { + return 0, err + } + + *m.dest = slice + return n, nil +} + +func (m *mapSliceModel) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if err := m.initKeys(); err != nil { + return nil, err + } + + for i, k := range m.keys { + if i > 0 { + b = append(b, ", "...) + } + b = fmter.AppendIdent(b, k) + } + + return b, nil +} + +func (m *mapSliceModel) appendValues(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if err := m.initKeys(); err != nil { + return nil, err + } + slice := *m.dest + + b = append(b, "VALUES "...) + if m.db.features.Has(feature.ValuesRow) { + b = append(b, "ROW("...) + } else { + b = append(b, '(') + } + + if fmter.IsNop() { + for i := range m.keys { + if i > 0 { + b = append(b, ", "...) + } + b = append(b, '?') + } + return b, nil + } + + for i, el := range slice { + if i > 0 { + b = append(b, "), "...) + if m.db.features.Has(feature.ValuesRow) { + b = append(b, "ROW("...) + } else { + b = append(b, '(') + } + } + + for j, key := range m.keys { + if j > 0 { + b = append(b, ", "...) + } + b = fmter.Dialect().Append(fmter, b, el[key]) + } + } + + b = append(b, ')') + + return b, nil +} + +func (m *mapSliceModel) initKeys() error { + if m.keys != nil { + return nil + } + + slice := *m.dest + if len(slice) == 0 { + return errors.New("bun: map slice is empty") + } + + first := slice[0] + keys := make([]string, 0, len(first)) + + for k := range first { + keys = append(keys, k) + } + + sort.Strings(keys) + m.keys = keys + + return nil +} diff --git a/vendor/github.com/uptrace/bun/model_scan.go b/vendor/github.com/uptrace/bun/model_scan.go new file mode 100644 index 000000000..6dd061fb2 --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_scan.go @@ -0,0 +1,54 @@ +package bun + +import ( + "context" + "database/sql" + "reflect" +) + +type scanModel struct { + db *DB + + dest []interface{} + scanIndex int +} + +var _ model = (*scanModel)(nil) + +func newScanModel(db *DB, dest []interface{}) *scanModel { + return &scanModel{ + db: db, + dest: dest, + } +} + +func (m *scanModel) Value() interface{} { + return m.dest +} + +func (m *scanModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + if !rows.Next() { + return 0, rows.Err() + } + + dest := makeDest(m, len(m.dest)) + + m.scanIndex = 0 + if err := rows.Scan(dest...); err != nil { + return 0, err + } + + return 1, nil +} + +func (m *scanModel) ScanRow(ctx context.Context, rows *sql.Rows) error { + return rows.Scan(m.dest...) +} + +func (m *scanModel) Scan(src interface{}) error { + dest := reflect.ValueOf(m.dest[m.scanIndex]) + m.scanIndex++ + + scanner := m.db.dialect.Scanner(dest.Type()) + return scanner(dest, src) +} diff --git a/vendor/github.com/uptrace/bun/model_slice.go b/vendor/github.com/uptrace/bun/model_slice.go new file mode 100644 index 000000000..afe804382 --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_slice.go @@ -0,0 +1,82 @@ +package bun + +import ( + "context" + "database/sql" + "reflect" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type sliceInfo struct { + nextElem func() reflect.Value + scan schema.ScannerFunc +} + +type sliceModel struct { + dest []interface{} + values []reflect.Value + scanIndex int + info []sliceInfo +} + +var _ model = (*sliceModel)(nil) + +func newSliceModel(db *DB, dest []interface{}, values []reflect.Value) *sliceModel { + return &sliceModel{ + dest: dest, + values: values, + } +} + +func (m *sliceModel) Value() interface{} { + return m.dest +} + +func (m *sliceModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + m.info = make([]sliceInfo, len(m.values)) + for i, v := range m.values { + if v.IsValid() && v.Len() > 0 { + v.Set(v.Slice(0, 0)) + } + + m.info[i] = sliceInfo{ + nextElem: internal.MakeSliceNextElemFunc(v), + scan: schema.Scanner(v.Type().Elem()), + } + } + + if len(columns) == 0 { + return 0, nil + } + dest := makeDest(m, len(columns)) + + var n int + + for rows.Next() { + m.scanIndex = 0 + if err := rows.Scan(dest...); err != nil { + return 0, err + } + n++ + } + if err := rows.Err(); err != nil { + return 0, err + } + + return n, nil +} + +func (m *sliceModel) Scan(src interface{}) error { + info := m.info[m.scanIndex] + m.scanIndex++ + + dest := info.nextElem() + return info.scan(dest, src) +} diff --git a/vendor/github.com/uptrace/bun/model_table_has_many.go b/vendor/github.com/uptrace/bun/model_table_has_many.go new file mode 100644 index 000000000..e64b7088d --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_table_has_many.go @@ -0,0 +1,149 @@ +package bun + +import ( + "context" + "database/sql" + "fmt" + "reflect" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type hasManyModel struct { + *sliceTableModel + baseTable *schema.Table + rel *schema.Relation + + baseValues map[internal.MapKey][]reflect.Value + structKey []interface{} +} + +var _ tableModel = (*hasManyModel)(nil) + +func newHasManyModel(j *join) *hasManyModel { + baseTable := j.BaseModel.Table() + joinModel := j.JoinModel.(*sliceTableModel) + baseValues := baseValues(joinModel, j.Relation.BaseFields) + if len(baseValues) == 0 { + return nil + } + m := hasManyModel{ + sliceTableModel: joinModel, + baseTable: baseTable, + rel: j.Relation, + + baseValues: baseValues, + } + if !m.sliceOfPtr { + m.strct = reflect.New(m.table.Type).Elem() + } + return &m +} + +func (m *hasManyModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + m.columns = columns + dest := makeDest(m, len(columns)) + + var n int + + for rows.Next() { + if m.sliceOfPtr { + m.strct = reflect.New(m.table.Type).Elem() + } else { + m.strct.Set(m.table.ZeroValue) + } + m.structInited = false + + m.scanIndex = 0 + m.structKey = m.structKey[:0] + if err := rows.Scan(dest...); err != nil { + return 0, err + } + + if err := m.parkStruct(); err != nil { + return 0, err + } + + n++ + } + if err := rows.Err(); err != nil { + return 0, err + } + + return n, nil +} + +func (m *hasManyModel) Scan(src interface{}) error { + column := m.columns[m.scanIndex] + m.scanIndex++ + + field, err := m.table.Field(column) + if err != nil { + return err + } + + if err := field.ScanValue(m.strct, src); err != nil { + return err + } + + for _, f := range m.rel.JoinFields { + if f.Name == field.Name { + m.structKey = append(m.structKey, field.Value(m.strct).Interface()) + break + } + } + + return nil +} + +func (m *hasManyModel) parkStruct() error { + baseValues, ok := m.baseValues[internal.NewMapKey(m.structKey)] + if !ok { + return fmt.Errorf( + "bun: has-many relation=%s does not have base %s with id=%q (check join conditions)", + m.rel.Field.GoName, m.baseTable, m.structKey) + } + + for i, v := range baseValues { + if !m.sliceOfPtr { + v.Set(reflect.Append(v, m.strct)) + continue + } + + if i == 0 { + v.Set(reflect.Append(v, m.strct.Addr())) + continue + } + + clone := reflect.New(m.strct.Type()).Elem() + clone.Set(m.strct) + v.Set(reflect.Append(v, clone.Addr())) + } + + return nil +} + +func baseValues(model tableModel, fields []*schema.Field) map[internal.MapKey][]reflect.Value { + fieldIndex := model.Relation().Field.Index + m := make(map[internal.MapKey][]reflect.Value) + key := make([]interface{}, 0, len(fields)) + walk(model.Root(), model.ParentIndex(), func(v reflect.Value) { + key = modelKey(key[:0], v, fields) + mapKey := internal.NewMapKey(key) + m[mapKey] = append(m[mapKey], v.FieldByIndex(fieldIndex)) + }) + return m +} + +func modelKey(key []interface{}, strct reflect.Value, fields []*schema.Field) []interface{} { + for _, f := range fields { + key = append(key, f.Value(strct).Interface()) + } + return key +} diff --git a/vendor/github.com/uptrace/bun/model_table_m2m.go b/vendor/github.com/uptrace/bun/model_table_m2m.go new file mode 100644 index 000000000..4357e3a8e --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_table_m2m.go @@ -0,0 +1,138 @@ +package bun + +import ( + "context" + "database/sql" + "fmt" + "reflect" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type m2mModel struct { + *sliceTableModel + baseTable *schema.Table + rel *schema.Relation + + baseValues map[internal.MapKey][]reflect.Value + structKey []interface{} +} + +var _ tableModel = (*m2mModel)(nil) + +func newM2MModel(j *join) *m2mModel { + baseTable := j.BaseModel.Table() + joinModel := j.JoinModel.(*sliceTableModel) + baseValues := baseValues(joinModel, baseTable.PKs) + if len(baseValues) == 0 { + return nil + } + m := &m2mModel{ + sliceTableModel: joinModel, + baseTable: baseTable, + rel: j.Relation, + + baseValues: baseValues, + } + if !m.sliceOfPtr { + m.strct = reflect.New(m.table.Type).Elem() + } + return m +} + +func (m *m2mModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + m.columns = columns + dest := makeDest(m, len(columns)) + + var n int + + for rows.Next() { + if m.sliceOfPtr { + m.strct = reflect.New(m.table.Type).Elem() + } else { + m.strct.Set(m.table.ZeroValue) + } + m.structInited = false + + m.scanIndex = 0 + m.structKey = m.structKey[:0] + if err := rows.Scan(dest...); err != nil { + return 0, err + } + + if err := m.parkStruct(); err != nil { + return 0, err + } + + n++ + } + if err := rows.Err(); err != nil { + return 0, err + } + + return n, nil +} + +func (m *m2mModel) Scan(src interface{}) error { + column := m.columns[m.scanIndex] + m.scanIndex++ + + field, ok := m.table.FieldMap[column] + if !ok { + return m.scanM2MColumn(column, src) + } + + if err := field.ScanValue(m.strct, src); err != nil { + return err + } + + for _, fk := range m.rel.M2MBaseFields { + if fk.Name == field.Name { + m.structKey = append(m.structKey, field.Value(m.strct).Interface()) + break + } + } + + return nil +} + +func (m *m2mModel) scanM2MColumn(column string, src interface{}) error { + for _, field := range m.rel.M2MBaseFields { + if field.Name == column { + dest := reflect.New(field.IndirectType).Elem() + if err := field.Scan(dest, src); err != nil { + return err + } + m.structKey = append(m.structKey, dest.Interface()) + break + } + } + + _, err := m.scanColumn(column, src) + return err +} + +func (m *m2mModel) parkStruct() error { + baseValues, ok := m.baseValues[internal.NewMapKey(m.structKey)] + if !ok { + return fmt.Errorf( + "bun: m2m relation=%s does not have base %s with key=%q (check join conditions)", + m.rel.Field.GoName, m.baseTable, m.structKey) + } + + for _, v := range baseValues { + if m.sliceOfPtr { + v.Set(reflect.Append(v, m.strct.Addr())) + } else { + v.Set(reflect.Append(v, m.strct)) + } + } + + return nil +} diff --git a/vendor/github.com/uptrace/bun/model_table_slice.go b/vendor/github.com/uptrace/bun/model_table_slice.go new file mode 100644 index 000000000..67e7c71e7 --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_table_slice.go @@ -0,0 +1,113 @@ +package bun + +import ( + "context" + "database/sql" + "reflect" + + "github.com/uptrace/bun/schema" +) + +type sliceTableModel struct { + structTableModel + + slice reflect.Value + sliceLen int + sliceOfPtr bool + nextElem func() reflect.Value +} + +var _ tableModel = (*sliceTableModel)(nil) + +func newSliceTableModel( + db *DB, dest interface{}, slice reflect.Value, elemType reflect.Type, +) *sliceTableModel { + m := &sliceTableModel{ + structTableModel: structTableModel{ + db: db, + table: db.Table(elemType), + dest: dest, + root: slice, + }, + + slice: slice, + sliceLen: slice.Len(), + nextElem: makeSliceNextElemFunc(slice), + } + m.init(slice.Type()) + return m +} + +func (m *sliceTableModel) init(sliceType reflect.Type) { + switch sliceType.Elem().Kind() { + case reflect.Ptr, reflect.Interface: + m.sliceOfPtr = true + } +} + +func (m *sliceTableModel) Join(name string, apply func(*SelectQuery) *SelectQuery) *join { + return m.join(m.slice, name, apply) +} + +func (m *sliceTableModel) Bind(bind reflect.Value) { + m.slice = bind.Field(m.index[len(m.index)-1]) +} + +func (m *sliceTableModel) SetCap(cap int) { + if cap > 100 { + cap = 100 + } + if m.slice.Cap() < cap { + m.slice.Set(reflect.MakeSlice(m.slice.Type(), 0, cap)) + } +} + +func (m *sliceTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + m.columns = columns + dest := makeDest(m, len(columns)) + + if m.slice.IsValid() && m.slice.Len() > 0 { + m.slice.Set(m.slice.Slice(0, 0)) + } + + var n int + + for rows.Next() { + m.strct = m.nextElem() + m.structInited = false + + if err := m.scanRow(ctx, rows, dest); err != nil { + return 0, err + } + + n++ + } + if err := rows.Err(); err != nil { + return 0, err + } + + return n, nil +} + +// Inherit these hooks from structTableModel. +var ( + _ schema.BeforeScanHook = (*sliceTableModel)(nil) + _ schema.AfterScanHook = (*sliceTableModel)(nil) +) + +func (m *sliceTableModel) updateSoftDeleteField() error { + sliceLen := m.slice.Len() + for i := 0; i < sliceLen; i++ { + strct := indirect(m.slice.Index(i)) + fv := m.table.SoftDeleteField.Value(strct) + if err := m.table.UpdateSoftDeleteField(fv); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/uptrace/bun/model_table_struct.go b/vendor/github.com/uptrace/bun/model_table_struct.go new file mode 100644 index 000000000..3bb0c14dd --- /dev/null +++ b/vendor/github.com/uptrace/bun/model_table_struct.go @@ -0,0 +1,345 @@ +package bun + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "strings" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/schema" +) + +type structTableModel struct { + db *DB + table *schema.Table + + rel *schema.Relation + joins []join + + dest interface{} + root reflect.Value + index []int + + strct reflect.Value + structInited bool + structInitErr error + + columns []string + scanIndex int +} + +var _ tableModel = (*structTableModel)(nil) + +func newStructTableModel(db *DB, dest interface{}, table *schema.Table) *structTableModel { + return &structTableModel{ + db: db, + table: table, + dest: dest, + } +} + +func newStructTableModelValue(db *DB, dest interface{}, v reflect.Value) *structTableModel { + return &structTableModel{ + db: db, + table: db.Table(v.Type()), + dest: dest, + root: v, + strct: v, + } +} + +func (m *structTableModel) Value() interface{} { + return m.dest +} + +func (m *structTableModel) Table() *schema.Table { + return m.table +} + +func (m *structTableModel) Relation() *schema.Relation { + return m.rel +} + +func (m *structTableModel) Root() reflect.Value { + return m.root +} + +func (m *structTableModel) Index() []int { + return m.index +} + +func (m *structTableModel) ParentIndex() []int { + return m.index[:len(m.index)-len(m.rel.Field.Index)] +} + +func (m *structTableModel) Mount(host reflect.Value) { + m.strct = host.FieldByIndex(m.rel.Field.Index) + m.structInited = false +} + +func (m *structTableModel) initStruct() error { + if m.structInited { + return m.structInitErr + } + m.structInited = true + + switch m.strct.Kind() { + case reflect.Invalid: + m.structInitErr = errNilModel + return m.structInitErr + case reflect.Interface: + m.strct = m.strct.Elem() + } + + if m.strct.Kind() == reflect.Ptr { + if m.strct.IsNil() { + m.strct.Set(reflect.New(m.strct.Type().Elem())) + m.strct = m.strct.Elem() + } else { + m.strct = m.strct.Elem() + } + } + + m.mountJoins() + + return nil +} + +func (m *structTableModel) mountJoins() { + for i := range m.joins { + j := &m.joins[i] + switch j.Relation.Type { + case schema.HasOneRelation, schema.BelongsToRelation: + j.JoinModel.Mount(m.strct) + } + } +} + +var _ schema.BeforeScanHook = (*structTableModel)(nil) + +func (m *structTableModel) BeforeScan(ctx context.Context) error { + if !m.table.HasBeforeScanHook() { + return nil + } + return callBeforeScanHook(ctx, m.strct.Addr()) +} + +var _ schema.AfterScanHook = (*structTableModel)(nil) + +func (m *structTableModel) AfterScan(ctx context.Context) error { + if !m.table.HasAfterScanHook() || !m.structInited { + return nil + } + + var firstErr error + + if err := callAfterScanHook(ctx, m.strct.Addr()); err != nil && firstErr == nil { + firstErr = err + } + + for _, j := range m.joins { + switch j.Relation.Type { + case schema.HasOneRelation, schema.BelongsToRelation: + if err := j.JoinModel.AfterScan(ctx); err != nil && firstErr == nil { + firstErr = err + } + } + } + + return firstErr +} + +func (m *structTableModel) GetJoin(name string) *join { + for i := range m.joins { + j := &m.joins[i] + if j.Relation.Field.Name == name || j.Relation.Field.GoName == name { + return j + } + } + return nil +} + +func (m *structTableModel) GetJoins() []join { + return m.joins +} + +func (m *structTableModel) AddJoin(j join) *join { + m.joins = append(m.joins, j) + return &m.joins[len(m.joins)-1] +} + +func (m *structTableModel) Join(name string, apply func(*SelectQuery) *SelectQuery) *join { + return m.join(m.strct, name, apply) +} + +func (m *structTableModel) join( + bind reflect.Value, name string, apply func(*SelectQuery) *SelectQuery, +) *join { + path := strings.Split(name, ".") + index := make([]int, 0, len(path)) + + currJoin := join{ + BaseModel: m, + JoinModel: m, + } + var lastJoin *join + + for _, name := range path { + relation, ok := currJoin.JoinModel.Table().Relations[name] + if !ok { + return nil + } + + currJoin.Relation = relation + index = append(index, relation.Field.Index...) + + if j := currJoin.JoinModel.GetJoin(name); j != nil { + currJoin.BaseModel = j.BaseModel + currJoin.JoinModel = j.JoinModel + + lastJoin = j + } else { + model, err := newTableModelIndex(m.db, m.table, bind, index, relation) + if err != nil { + return nil + } + + currJoin.Parent = lastJoin + currJoin.BaseModel = currJoin.JoinModel + currJoin.JoinModel = model + + lastJoin = currJoin.BaseModel.AddJoin(currJoin) + } + } + + // No joins with such name. + if lastJoin == nil { + return nil + } + if apply != nil { + lastJoin.ApplyQueryFunc = apply + } + + return lastJoin +} + +func (m *structTableModel) updateSoftDeleteField() error { + fv := m.table.SoftDeleteField.Value(m.strct) + return m.table.UpdateSoftDeleteField(fv) +} + +func (m *structTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error) { + if !rows.Next() { + return 0, rows.Err() + } + + if err := m.ScanRow(ctx, rows); err != nil { + return 0, err + } + + // For inserts, SQLite3 can return a row like it was inserted sucessfully and then + // an actual error for the next row. See issues/100. + if m.db.dialect.Name() == dialect.SQLite { + _ = rows.Next() + if err := rows.Err(); err != nil { + return 0, err + } + } + + return 1, nil +} + +func (m *structTableModel) ScanRow(ctx context.Context, rows *sql.Rows) error { + columns, err := rows.Columns() + if err != nil { + return err + } + + m.columns = columns + dest := makeDest(m, len(columns)) + + return m.scanRow(ctx, rows, dest) +} + +func (m *structTableModel) scanRow(ctx context.Context, rows *sql.Rows, dest []interface{}) error { + if err := m.BeforeScan(ctx); err != nil { + return err + } + + m.scanIndex = 0 + if err := rows.Scan(dest...); err != nil { + return err + } + + if err := m.AfterScan(ctx); err != nil { + return err + } + + return nil +} + +func (m *structTableModel) Scan(src interface{}) error { + column := m.columns[m.scanIndex] + m.scanIndex++ + + return m.ScanColumn(unquote(column), src) +} + +func (m *structTableModel) ScanColumn(column string, src interface{}) error { + if ok, err := m.scanColumn(column, src); ok { + return err + } + if column == "" || column[0] == '_' || m.db.flags.Has(discardUnknownColumns) { + return nil + } + return fmt.Errorf("bun: %s does not have column %q", m.table.TypeName, column) +} + +func (m *structTableModel) scanColumn(column string, src interface{}) (bool, error) { + if src != nil { + if err := m.initStruct(); err != nil { + return true, err + } + } + + if field, ok := m.table.FieldMap[column]; ok { + return true, field.ScanValue(m.strct, src) + } + + if joinName, column := splitColumn(column); joinName != "" { + if join := m.GetJoin(joinName); join != nil { + return true, join.JoinModel.ScanColumn(column, src) + } + if m.table.ModelName == joinName { + return true, m.ScanColumn(column, src) + } + } + + return false, nil +} + +func (m *structTableModel) AppendNamedArg( + fmter schema.Formatter, b []byte, name string, +) ([]byte, bool) { + return m.table.AppendNamedArg(fmter, b, name, m.strct) +} + +// sqlite3 sometimes does not unquote columns. +func unquote(s string) string { + if s == "" { + return s + } + if s[0] == '"' && s[len(s)-1] == '"' { + return s[1 : len(s)-1] + } + return s +} + +func splitColumn(s string) (string, string) { + if i := strings.Index(s, "__"); i >= 0 { + return s[:i], s[i+2:] + } + return "", s +} diff --git a/vendor/github.com/uptrace/bun/query_base.go b/vendor/github.com/uptrace/bun/query_base.go new file mode 100644 index 000000000..1a7c32720 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_base.go @@ -0,0 +1,874 @@ +package bun + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +const ( + wherePKFlag internal.Flag = 1 << iota + forceDeleteFlag + deletedFlag + allWithDeletedFlag +) + +type withQuery struct { + name string + query schema.QueryAppender +} + +// IConn is a common interface for *sql.DB, *sql.Conn, and *sql.Tx. +type IConn interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row +} + +var ( + _ IConn = (*sql.DB)(nil) + _ IConn = (*sql.Conn)(nil) + _ IConn = (*sql.Tx)(nil) + _ IConn = (*DB)(nil) + _ IConn = (*Conn)(nil) + _ IConn = (*Tx)(nil) +) + +// IDB is a common interface for *bun.DB, bun.Conn, and bun.Tx. +type IDB interface { + IConn + + NewValues(model interface{}) *ValuesQuery + NewSelect() *SelectQuery + NewInsert() *InsertQuery + NewUpdate() *UpdateQuery + NewDelete() *DeleteQuery + NewCreateTable() *CreateTableQuery + NewDropTable() *DropTableQuery + NewCreateIndex() *CreateIndexQuery + NewDropIndex() *DropIndexQuery + NewTruncateTable() *TruncateTableQuery + NewAddColumn() *AddColumnQuery + NewDropColumn() *DropColumnQuery +} + +var ( + _ IConn = (*DB)(nil) + _ IConn = (*Conn)(nil) + _ IConn = (*Tx)(nil) +) + +type baseQuery struct { + db *DB + conn IConn + + model model + err error + + tableModel tableModel + table *schema.Table + + with []withQuery + modelTable schema.QueryWithArgs + tables []schema.QueryWithArgs + columns []schema.QueryWithArgs + + flags internal.Flag +} + +func (q *baseQuery) DB() *DB { + return q.db +} + +func (q *baseQuery) GetModel() Model { + return q.model +} + +func (q *baseQuery) setConn(db IConn) { + // Unwrap Bun wrappers to not call query hooks twice. + switch db := db.(type) { + case *DB: + q.conn = db.DB + case Conn: + q.conn = db.Conn + case Tx: + q.conn = db.Tx + default: + q.conn = db + } +} + +// TODO: rename to setModel +func (q *baseQuery) setTableModel(modeli interface{}) { + model, err := newSingleModel(q.db, modeli) + if err != nil { + q.setErr(err) + return + } + + q.model = model + if tm, ok := model.(tableModel); ok { + q.tableModel = tm + q.table = tm.Table() + } +} + +func (q *baseQuery) setErr(err error) { + if q.err == nil { + q.err = err + } +} + +func (q *baseQuery) getModel(dest []interface{}) (model, error) { + if len(dest) == 0 { + if q.model != nil { + return q.model, nil + } + return nil, errNilModel + } + return newModel(q.db, dest) +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) checkSoftDelete() error { + if q.table == nil { + return errors.New("bun: can't use soft deletes without a table") + } + if q.table.SoftDeleteField == nil { + return fmt.Errorf("%s does not have a soft delete field", q.table) + } + if q.tableModel == nil { + return errors.New("bun: can't use soft deletes without a table model") + } + return nil +} + +// Deleted adds `WHERE deleted_at IS NOT NULL` clause for soft deleted models. +func (q *baseQuery) whereDeleted() { + if err := q.checkSoftDelete(); err != nil { + q.setErr(err) + return + } + q.flags = q.flags.Set(deletedFlag) + q.flags = q.flags.Remove(allWithDeletedFlag) +} + +// AllWithDeleted changes query to return all rows including soft deleted ones. +func (q *baseQuery) whereAllWithDeleted() { + if err := q.checkSoftDelete(); err != nil { + q.setErr(err) + return + } + q.flags = q.flags.Set(allWithDeletedFlag) + q.flags = q.flags.Remove(deletedFlag) +} + +func (q *baseQuery) isSoftDelete() bool { + if q.table != nil { + return q.table.SoftDeleteField != nil && !q.flags.Has(allWithDeletedFlag) + } + return false +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) addWith(name string, query schema.QueryAppender) { + q.with = append(q.with, withQuery{ + name: name, + query: query, + }) +} + +func (q *baseQuery) appendWith(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if len(q.with) == 0 { + return b, nil + } + + b = append(b, "WITH "...) + for i, with := range q.with { + if i > 0 { + b = append(b, ", "...) + } + + b = fmter.AppendIdent(b, with.name) + if q, ok := with.query.(schema.ColumnsAppender); ok { + b = append(b, " ("...) + b, err = q.AppendColumns(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ")"...) + } + + b = append(b, " AS ("...) + + b, err = with.query.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, ')') + } + b = append(b, ' ') + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) addTable(table schema.QueryWithArgs) { + q.tables = append(q.tables, table) +} + +func (q *baseQuery) addColumn(column schema.QueryWithArgs) { + q.columns = append(q.columns, column) +} + +func (q *baseQuery) excludeColumn(columns []string) { + if q.columns == nil { + for _, f := range q.table.Fields { + q.columns = append(q.columns, schema.UnsafeIdent(f.Name)) + } + } + + if len(columns) == 1 && columns[0] == "*" { + q.columns = make([]schema.QueryWithArgs, 0) + return + } + + for _, column := range columns { + if !q._excludeColumn(column) { + q.setErr(fmt.Errorf("bun: can't find column=%q", column)) + return + } + } +} + +func (q *baseQuery) _excludeColumn(column string) bool { + for i, col := range q.columns { + if col.Args == nil && col.Query == column { + q.columns = append(q.columns[:i], q.columns[i+1:]...) + return true + } + } + return false +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) modelHasTableName() bool { + return !q.modelTable.IsZero() || q.table != nil +} + +func (q *baseQuery) hasTables() bool { + return q.modelHasTableName() || len(q.tables) > 0 +} + +func (q *baseQuery) appendTables( + fmter schema.Formatter, b []byte, +) (_ []byte, err error) { + return q._appendTables(fmter, b, false) +} + +func (q *baseQuery) appendTablesWithAlias( + fmter schema.Formatter, b []byte, +) (_ []byte, err error) { + return q._appendTables(fmter, b, true) +} + +func (q *baseQuery) _appendTables( + fmter schema.Formatter, b []byte, withAlias bool, +) (_ []byte, err error) { + startLen := len(b) + + if q.modelHasTableName() { + if !q.modelTable.IsZero() { + b, err = q.modelTable.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } else { + b = fmter.AppendQuery(b, string(q.table.SQLNameForSelects)) + if withAlias && q.table.SQLAlias != q.table.SQLNameForSelects { + b = append(b, " AS "...) + b = append(b, q.table.SQLAlias...) + } + } + } + + for _, table := range q.tables { + if len(b) > startLen { + b = append(b, ", "...) + } + b, err = table.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *baseQuery) appendFirstTable(fmter schema.Formatter, b []byte) ([]byte, error) { + return q._appendFirstTable(fmter, b, false) +} + +func (q *baseQuery) appendFirstTableWithAlias( + fmter schema.Formatter, b []byte, +) ([]byte, error) { + return q._appendFirstTable(fmter, b, true) +} + +func (q *baseQuery) _appendFirstTable( + fmter schema.Formatter, b []byte, withAlias bool, +) ([]byte, error) { + if !q.modelTable.IsZero() { + return q.modelTable.AppendQuery(fmter, b) + } + + if q.table != nil { + b = fmter.AppendQuery(b, string(q.table.SQLName)) + if withAlias { + b = append(b, " AS "...) + b = append(b, q.table.SQLAlias...) + } + return b, nil + } + + if len(q.tables) > 0 { + return q.tables[0].AppendQuery(fmter, b) + } + + return nil, errors.New("bun: query does not have a table") +} + +func (q *baseQuery) hasMultiTables() bool { + if q.modelHasTableName() { + return len(q.tables) >= 1 + } + return len(q.tables) >= 2 +} + +func (q *baseQuery) appendOtherTables(fmter schema.Formatter, b []byte) (_ []byte, err error) { + tables := q.tables + if !q.modelHasTableName() { + tables = tables[1:] + } + for i, table := range tables { + if i > 0 { + b = append(b, ", "...) + } + b, err = table.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) { + for i, f := range q.columns { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +func (q *baseQuery) getFields() ([]*schema.Field, error) { + table := q.tableModel.Table() + + if len(q.columns) == 0 { + return table.Fields, nil + } + + fields, err := q._getFields(false) + if err != nil { + return nil, err + } + + return fields, nil +} + +func (q *baseQuery) getDataFields() ([]*schema.Field, error) { + if len(q.columns) == 0 { + return q.table.DataFields, nil + } + return q._getFields(true) +} + +func (q *baseQuery) _getFields(omitPK bool) ([]*schema.Field, error) { + fields := make([]*schema.Field, 0, len(q.columns)) + for _, col := range q.columns { + if col.Args != nil { + continue + } + + field, err := q.table.Field(col.Query) + if err != nil { + return nil, err + } + + if omitPK && field.IsPK { + continue + } + + fields = append(fields, field) + } + return fields, nil +} + +func (q *baseQuery) scan( + ctx context.Context, + queryApp schema.QueryAppender, + query string, + model model, + hasDest bool, +) (res result, _ error) { + ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil) + + rows, err := q.conn.QueryContext(ctx, query) + if err != nil { + q.db.afterQuery(ctx, event, nil, err) + return res, err + } + defer rows.Close() + + n, err := model.ScanRows(ctx, rows) + if err != nil { + q.db.afterQuery(ctx, event, nil, err) + return res, err + } + + res.n = n + if n == 0 && hasDest && isSingleRowModel(model) { + err = sql.ErrNoRows + } + + q.db.afterQuery(ctx, event, nil, err) + + return res, err +} + +func (q *baseQuery) exec( + ctx context.Context, + queryApp schema.QueryAppender, + query string, +) (res result, _ error) { + ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil) + + r, err := q.conn.ExecContext(ctx, query) + if err != nil { + q.db.afterQuery(ctx, event, nil, err) + return res, err + } + + res.r = r + + q.db.afterQuery(ctx, event, nil, err) + return res, nil +} + +//------------------------------------------------------------------------------ + +func (q *baseQuery) AppendNamedArg(fmter schema.Formatter, b []byte, name string) ([]byte, bool) { + if q.table == nil { + return b, false + } + + if m, ok := q.tableModel.(*structTableModel); ok { + if b, ok := m.AppendNamedArg(fmter, b, name); ok { + return b, ok + } + } + + switch name { + case "TableName": + b = fmter.AppendQuery(b, string(q.table.SQLName)) + return b, true + case "TableAlias": + b = fmter.AppendQuery(b, string(q.table.SQLAlias)) + return b, true + case "PKs": + b = appendColumns(b, "", q.table.PKs) + return b, true + case "TablePKs": + b = appendColumns(b, q.table.SQLAlias, q.table.PKs) + return b, true + case "Columns": + b = appendColumns(b, "", q.table.Fields) + return b, true + case "TableColumns": + b = appendColumns(b, q.table.SQLAlias, q.table.Fields) + return b, true + } + + return b, false +} + +func appendColumns(b []byte, table schema.Safe, fields []*schema.Field) []byte { + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + + if len(table) > 0 { + b = append(b, table...) + b = append(b, '.') + } + b = append(b, f.SQLName...) + } + return b +} + +func formatterWithModel(fmter schema.Formatter, model schema.NamedArgAppender) schema.Formatter { + if fmter.IsNop() { + return fmter + } + return fmter.WithArg(model) +} + +//------------------------------------------------------------------------------ + +type whereBaseQuery struct { + baseQuery + + where []schema.QueryWithSep +} + +func (q *whereBaseQuery) addWhere(where schema.QueryWithSep) { + q.where = append(q.where, where) +} + +func (q *whereBaseQuery) addWhereGroup(sep string, where []schema.QueryWithSep) { + if len(where) == 0 { + return + } + + where[0].Sep = "" + + q.addWhere(schema.SafeQueryWithSep("", nil, sep+"(")) + q.where = append(q.where, where...) + q.addWhere(schema.SafeQueryWithSep("", nil, ")")) +} + +func (q *whereBaseQuery) mustAppendWhere( + fmter schema.Formatter, b []byte, withAlias bool, +) ([]byte, error) { + if len(q.where) == 0 && !q.flags.Has(wherePKFlag) { + err := errors.New("bun: Update and Delete queries require at least one Where") + return nil, err + } + return q.appendWhere(fmter, b, withAlias) +} + +func (q *whereBaseQuery) appendWhere( + fmter schema.Formatter, b []byte, withAlias bool, +) (_ []byte, err error) { + if len(q.where) == 0 && !q.isSoftDelete() && !q.flags.Has(wherePKFlag) { + return b, nil + } + + b = append(b, " WHERE "...) + startLen := len(b) + + if len(q.where) > 0 { + b, err = appendWhere(fmter, b, q.where) + if err != nil { + return nil, err + } + } + + if q.isSoftDelete() { + if len(b) > startLen { + b = append(b, " AND "...) + } + if withAlias { + b = append(b, q.tableModel.Table().SQLAlias...) + b = append(b, '.') + } + b = append(b, q.tableModel.Table().SoftDeleteField.SQLName...) + if q.flags.Has(deletedFlag) { + b = append(b, " IS NOT NULL"...) + } else { + b = append(b, " IS NULL"...) + } + } + + if q.flags.Has(wherePKFlag) { + if len(b) > startLen { + b = append(b, " AND "...) + } + b, err = q.appendWherePK(fmter, b, withAlias) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func appendWhere( + fmter schema.Formatter, b []byte, where []schema.QueryWithSep, +) (_ []byte, err error) { + for i, where := range where { + if i > 0 || where.Sep == "(" { + b = append(b, where.Sep...) + } + + if where.Query == "" && where.Args == nil { + continue + } + + b = append(b, '(') + b, err = where.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ')') + } + return b, nil +} + +func (q *whereBaseQuery) appendWherePK( + fmter schema.Formatter, b []byte, withAlias bool, +) (_ []byte, err error) { + if q.table == nil { + err := fmt.Errorf("bun: got %T, but WherePK requires a struct or slice-based model", q.model) + return nil, err + } + if err := q.table.CheckPKs(); err != nil { + return nil, err + } + + switch model := q.tableModel.(type) { + case *structTableModel: + return q.appendWherePKStruct(fmter, b, model, withAlias) + case *sliceTableModel: + return q.appendWherePKSlice(fmter, b, model, withAlias) + } + + return nil, fmt.Errorf("bun: WherePK does not support %T", q.tableModel) +} + +func (q *whereBaseQuery) appendWherePKStruct( + fmter schema.Formatter, b []byte, model *structTableModel, withAlias bool, +) (_ []byte, err error) { + if !model.strct.IsValid() { + return nil, errNilModel + } + + isTemplate := fmter.IsNop() + b = append(b, '(') + for i, f := range q.table.PKs { + if i > 0 { + b = append(b, " AND "...) + } + if withAlias { + b = append(b, q.table.SQLAlias...) + b = append(b, '.') + } + b = append(b, f.SQLName...) + b = append(b, " = "...) + if isTemplate { + b = append(b, '?') + } else { + b = f.AppendValue(fmter, b, model.strct) + } + } + b = append(b, ')') + return b, nil +} + +func (q *whereBaseQuery) appendWherePKSlice( + fmter schema.Formatter, b []byte, model *sliceTableModel, withAlias bool, +) (_ []byte, err error) { + if len(q.table.PKs) > 1 { + b = append(b, '(') + } + if withAlias { + b = appendColumns(b, q.table.SQLAlias, q.table.PKs) + } else { + b = appendColumns(b, "", q.table.PKs) + } + if len(q.table.PKs) > 1 { + b = append(b, ')') + } + + b = append(b, " IN ("...) + + isTemplate := fmter.IsNop() + slice := model.slice + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + if isTemplate { + break + } + b = append(b, ", "...) + } + + el := indirect(slice.Index(i)) + + if len(q.table.PKs) > 1 { + b = append(b, '(') + } + for i, f := range q.table.PKs { + if i > 0 { + b = append(b, ", "...) + } + if isTemplate { + b = append(b, '?') + } else { + b = f.AppendValue(fmter, b, el) + } + } + if len(q.table.PKs) > 1 { + b = append(b, ')') + } + } + + b = append(b, ')') + + return b, nil +} + +//------------------------------------------------------------------------------ + +type returningQuery struct { + returning []schema.QueryWithArgs + returningFields []*schema.Field +} + +func (q *returningQuery) addReturning(ret schema.QueryWithArgs) { + q.returning = append(q.returning, ret) +} + +func (q *returningQuery) addReturningField(field *schema.Field) { + if len(q.returning) > 0 { + return + } + for _, f := range q.returningFields { + if f == field { + return + } + } + q.returningFields = append(q.returningFields, field) +} + +func (q *returningQuery) hasReturning() bool { + if len(q.returning) == 1 { + switch q.returning[0].Query { + case "null", "NULL": + return false + } + } + return len(q.returning) > 0 || len(q.returningFields) > 0 +} + +func (q *returningQuery) appendReturning( + fmter schema.Formatter, b []byte, +) (_ []byte, err error) { + if !q.hasReturning() { + return b, nil + } + + b = append(b, " RETURNING "...) + + for i, f := range q.returning { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + if len(q.returning) > 0 { + return b, nil + } + + b = appendColumns(b, "", q.returningFields) + return b, nil +} + +//------------------------------------------------------------------------------ + +type columnValue struct { + column string + value schema.QueryWithArgs +} + +type customValueQuery struct { + modelValues map[string]schema.QueryWithArgs + extraValues []columnValue +} + +func (q *customValueQuery) addValue( + table *schema.Table, column string, value string, args []interface{}, +) { + if _, ok := table.FieldMap[column]; ok { + if q.modelValues == nil { + q.modelValues = make(map[string]schema.QueryWithArgs) + } + q.modelValues[column] = schema.SafeQuery(value, args) + } else { + q.extraValues = append(q.extraValues, columnValue{ + column: column, + value: schema.SafeQuery(value, args), + }) + } +} + +//------------------------------------------------------------------------------ + +type setQuery struct { + set []schema.QueryWithArgs +} + +func (q *setQuery) addSet(set schema.QueryWithArgs) { + q.set = append(q.set, set) +} + +func (q setQuery) appendSet(fmter schema.Formatter, b []byte) (_ []byte, err error) { + for i, f := range q.set { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +//------------------------------------------------------------------------------ + +type cascadeQuery struct { + restrict bool +} + +func (q cascadeQuery) appendCascade(fmter schema.Formatter, b []byte) []byte { + if !fmter.HasFeature(feature.TableCascade) { + return b + } + if q.restrict { + b = append(b, " RESTRICT"...) + } else { + b = append(b, " CASCADE"...) + } + return b +} diff --git a/vendor/github.com/uptrace/bun/query_column_add.go b/vendor/github.com/uptrace/bun/query_column_add.go new file mode 100644 index 000000000..ce2f60bf0 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_column_add.go @@ -0,0 +1,105 @@ +package bun + +import ( + "context" + "database/sql" + "fmt" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type AddColumnQuery struct { + baseQuery +} + +func NewAddColumnQuery(db *DB) *AddColumnQuery { + q := &AddColumnQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + } + return q +} + +func (q *AddColumnQuery) Conn(db IConn) *AddColumnQuery { + q.setConn(db) + return q +} + +func (q *AddColumnQuery) Model(model interface{}) *AddColumnQuery { + q.setTableModel(model) + return q +} + +//------------------------------------------------------------------------------ + +func (q *AddColumnQuery) Table(tables ...string) *AddColumnQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *AddColumnQuery) TableExpr(query string, args ...interface{}) *AddColumnQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *AddColumnQuery) ModelTableExpr(query string, args ...interface{}) *AddColumnQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *AddColumnQuery) ColumnExpr(query string, args ...interface{}) *AddColumnQuery { + q.addColumn(schema.SafeQuery(query, args)) + return q +} + +//------------------------------------------------------------------------------ + +func (q *AddColumnQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + if len(q.columns) != 1 { + return nil, fmt.Errorf("bun: AddColumnQuery requires exactly one column") + } + + b = append(b, "ALTER TABLE "...) + + b, err = q.appendFirstTable(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, " ADD "...) + + b, err = q.columns[0].AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *AddColumnQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err := q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + return res, nil +} diff --git a/vendor/github.com/uptrace/bun/query_column_drop.go b/vendor/github.com/uptrace/bun/query_column_drop.go new file mode 100644 index 000000000..5684beeb3 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_column_drop.go @@ -0,0 +1,112 @@ +package bun + +import ( + "context" + "database/sql" + "fmt" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type DropColumnQuery struct { + baseQuery +} + +func NewDropColumnQuery(db *DB) *DropColumnQuery { + q := &DropColumnQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + } + return q +} + +func (q *DropColumnQuery) Conn(db IConn) *DropColumnQuery { + q.setConn(db) + return q +} + +func (q *DropColumnQuery) Model(model interface{}) *DropColumnQuery { + q.setTableModel(model) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropColumnQuery) Table(tables ...string) *DropColumnQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *DropColumnQuery) TableExpr(query string, args ...interface{}) *DropColumnQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *DropColumnQuery) ModelTableExpr(query string, args ...interface{}) *DropColumnQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropColumnQuery) Column(columns ...string) *DropColumnQuery { + for _, column := range columns { + q.addColumn(schema.UnsafeIdent(column)) + } + return q +} + +func (q *DropColumnQuery) ColumnExpr(query string, args ...interface{}) *DropColumnQuery { + q.addColumn(schema.SafeQuery(query, args)) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropColumnQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + if len(q.columns) != 1 { + return nil, fmt.Errorf("bun: DropColumnQuery requires exactly one column") + } + + b = append(b, "ALTER TABLE "...) + + b, err = q.appendFirstTable(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, " DROP COLUMN "...) + + b, err = q.columns[0].AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *DropColumnQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err := q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + return res, nil +} diff --git a/vendor/github.com/uptrace/bun/query_delete.go b/vendor/github.com/uptrace/bun/query_delete.go new file mode 100644 index 000000000..c0c5039c7 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_delete.go @@ -0,0 +1,256 @@ +package bun + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type DeleteQuery struct { + whereBaseQuery + returningQuery +} + +func NewDeleteQuery(db *DB) *DeleteQuery { + q := &DeleteQuery{ + whereBaseQuery: whereBaseQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + }, + } + return q +} + +func (q *DeleteQuery) Conn(db IConn) *DeleteQuery { + q.setConn(db) + return q +} + +func (q *DeleteQuery) Model(model interface{}) *DeleteQuery { + q.setTableModel(model) + return q +} + +// Apply calls the fn passing the DeleteQuery as an argument. +func (q *DeleteQuery) Apply(fn func(*DeleteQuery) *DeleteQuery) *DeleteQuery { + return fn(q) +} + +func (q *DeleteQuery) With(name string, query schema.QueryAppender) *DeleteQuery { + q.addWith(name, query) + return q +} + +func (q *DeleteQuery) Table(tables ...string) *DeleteQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *DeleteQuery) TableExpr(query string, args ...interface{}) *DeleteQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *DeleteQuery) ModelTableExpr(query string, args ...interface{}) *DeleteQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DeleteQuery) WherePK() *DeleteQuery { + q.flags = q.flags.Set(wherePKFlag) + return q +} + +func (q *DeleteQuery) Where(query string, args ...interface{}) *DeleteQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " AND ")) + return q +} + +func (q *DeleteQuery) WhereOr(query string, args ...interface{}) *DeleteQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " OR ")) + return q +} + +func (q *DeleteQuery) WhereGroup(sep string, fn func(*DeleteQuery) *DeleteQuery) *DeleteQuery { + saved := q.where + q.where = nil + + q = fn(q) + + where := q.where + q.where = saved + + q.addWhereGroup(sep, where) + + return q +} + +func (q *DeleteQuery) WhereDeleted() *DeleteQuery { + q.whereDeleted() + return q +} + +func (q *DeleteQuery) WhereAllWithDeleted() *DeleteQuery { + q.whereAllWithDeleted() + return q +} + +func (q *DeleteQuery) ForceDelete() *DeleteQuery { + q.flags = q.flags.Set(forceDeleteFlag) + return q +} + +//------------------------------------------------------------------------------ + +// Returning adds a RETURNING clause to the query. +// +// To suppress the auto-generated RETURNING clause, use `Returning("NULL")`. +func (q *DeleteQuery) Returning(query string, args ...interface{}) *DeleteQuery { + q.addReturning(schema.SafeQuery(query, args)) + return q +} + +func (q *DeleteQuery) hasReturning() bool { + if !q.db.features.Has(feature.Returning) { + return false + } + return q.returningQuery.hasReturning() +} + +//------------------------------------------------------------------------------ + +func (q *DeleteQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + fmter = formatterWithModel(fmter, q) + + if q.isSoftDelete() { + if err := q.tableModel.updateSoftDeleteField(); err != nil { + return nil, err + } + + upd := UpdateQuery{ + whereBaseQuery: q.whereBaseQuery, + returningQuery: q.returningQuery, + } + upd.Column(q.table.SoftDeleteField.Name) + return upd.AppendQuery(fmter, b) + } + + q = q.WhereAllWithDeleted() + withAlias := q.db.features.Has(feature.DeleteTableAlias) + + b, err = q.appendWith(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, "DELETE FROM "...) + + if withAlias { + b, err = q.appendFirstTableWithAlias(fmter, b) + } else { + b, err = q.appendFirstTable(fmter, b) + } + if err != nil { + return nil, err + } + + if q.hasMultiTables() { + b = append(b, " USING "...) + b, err = q.appendOtherTables(fmter, b) + if err != nil { + return nil, err + } + } + + b, err = q.mustAppendWhere(fmter, b, withAlias) + if err != nil { + return nil, err + } + + if len(q.returning) > 0 { + b, err = q.appendReturning(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *DeleteQuery) isSoftDelete() bool { + return q.tableModel != nil && q.table.SoftDeleteField != nil && !q.flags.Has(forceDeleteFlag) +} + +//------------------------------------------------------------------------------ + +func (q *DeleteQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + if q.table != nil { + if err := q.beforeDeleteHook(ctx); err != nil { + return nil, err + } + } + + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + var res sql.Result + + if hasDest := len(dest) > 0; hasDest || q.hasReturning() { + model, err := q.getModel(dest) + if err != nil { + return nil, err + } + + res, err = q.scan(ctx, q, query, model, hasDest) + if err != nil { + return nil, err + } + } else { + res, err = q.exec(ctx, q, query) + if err != nil { + return nil, err + } + } + + if q.table != nil { + if err := q.afterDeleteHook(ctx); err != nil { + return nil, err + } + } + + return res, nil +} + +func (q *DeleteQuery) beforeDeleteHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(BeforeDeleteHook); ok { + if err := hook.BeforeDelete(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *DeleteQuery) afterDeleteHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(AfterDeleteHook); ok { + if err := hook.AfterDelete(ctx, q); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/uptrace/bun/query_index_create.go b/vendor/github.com/uptrace/bun/query_index_create.go new file mode 100644 index 000000000..de7eb7aa0 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_index_create.go @@ -0,0 +1,242 @@ +package bun + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type CreateIndexQuery struct { + whereBaseQuery + + unique bool + fulltext bool + spatial bool + concurrently bool + ifNotExists bool + + index schema.QueryWithArgs + using schema.QueryWithArgs + include []schema.QueryWithArgs +} + +func NewCreateIndexQuery(db *DB) *CreateIndexQuery { + q := &CreateIndexQuery{ + whereBaseQuery: whereBaseQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + }, + } + return q +} + +func (q *CreateIndexQuery) Conn(db IConn) *CreateIndexQuery { + q.setConn(db) + return q +} + +func (q *CreateIndexQuery) Model(model interface{}) *CreateIndexQuery { + q.setTableModel(model) + return q +} + +func (q *CreateIndexQuery) Unique() *CreateIndexQuery { + q.unique = true + return q +} + +func (q *CreateIndexQuery) Concurrently() *CreateIndexQuery { + q.concurrently = true + return q +} + +func (q *CreateIndexQuery) IfNotExists() *CreateIndexQuery { + q.ifNotExists = true + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateIndexQuery) Index(query string) *CreateIndexQuery { + q.index = schema.UnsafeIdent(query) + return q +} + +func (q *CreateIndexQuery) IndexExpr(query string, args ...interface{}) *CreateIndexQuery { + q.index = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateIndexQuery) Table(tables ...string) *CreateIndexQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *CreateIndexQuery) TableExpr(query string, args ...interface{}) *CreateIndexQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *CreateIndexQuery) ModelTableExpr(query string, args ...interface{}) *CreateIndexQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +func (q *CreateIndexQuery) Using(query string, args ...interface{}) *CreateIndexQuery { + q.using = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateIndexQuery) Column(columns ...string) *CreateIndexQuery { + for _, column := range columns { + q.addColumn(schema.UnsafeIdent(column)) + } + return q +} + +func (q *CreateIndexQuery) ColumnExpr(query string, args ...interface{}) *CreateIndexQuery { + q.addColumn(schema.SafeQuery(query, args)) + return q +} + +func (q *CreateIndexQuery) ExcludeColumn(columns ...string) *CreateIndexQuery { + q.excludeColumn(columns) + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateIndexQuery) Include(columns ...string) *CreateIndexQuery { + for _, column := range columns { + q.include = append(q.include, schema.UnsafeIdent(column)) + } + return q +} + +func (q *CreateIndexQuery) IncludeExpr(query string, args ...interface{}) *CreateIndexQuery { + q.include = append(q.include, schema.SafeQuery(query, args)) + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateIndexQuery) Where(query string, args ...interface{}) *CreateIndexQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " AND ")) + return q +} + +func (q *CreateIndexQuery) WhereOr(query string, args ...interface{}) *CreateIndexQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " OR ")) + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateIndexQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + + b = append(b, "CREATE "...) + + if q.unique { + b = append(b, "UNIQUE "...) + } + if q.fulltext { + b = append(b, "FULLTEXT "...) + } + if q.spatial { + b = append(b, "SPATIAL "...) + } + + b = append(b, "INDEX "...) + + if q.concurrently { + b = append(b, "CONCURRENTLY "...) + } + if q.ifNotExists { + b = append(b, "IF NOT EXISTS "...) + } + + b, err = q.index.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, " ON "...) + b, err = q.appendFirstTable(fmter, b) + if err != nil { + return nil, err + } + + if !q.using.IsZero() { + b = append(b, " USING "...) + b, err = q.using.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + b = append(b, " ("...) + for i, col := range q.columns { + if i > 0 { + b = append(b, ", "...) + } + b, err = col.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + b = append(b, ')') + + if len(q.include) > 0 { + b = append(b, " INCLUDE ("...) + for i, col := range q.include { + if i > 0 { + b = append(b, ", "...) + } + b, err = col.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + b = append(b, ')') + } + + if len(q.where) > 0 { + b, err = appendWhere(fmter, b, q.where) + if err != nil { + return nil, err + } + } + + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *CreateIndexQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err := q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + return res, nil +} diff --git a/vendor/github.com/uptrace/bun/query_index_drop.go b/vendor/github.com/uptrace/bun/query_index_drop.go new file mode 100644 index 000000000..c922ff04f --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_index_drop.go @@ -0,0 +1,105 @@ +package bun + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type DropIndexQuery struct { + baseQuery + cascadeQuery + + concurrently bool + ifExists bool + + index schema.QueryWithArgs +} + +func NewDropIndexQuery(db *DB) *DropIndexQuery { + q := &DropIndexQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + } + return q +} + +func (q *DropIndexQuery) Conn(db IConn) *DropIndexQuery { + q.setConn(db) + return q +} + +func (q *DropIndexQuery) Model(model interface{}) *DropIndexQuery { + q.setTableModel(model) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropIndexQuery) Concurrently() *DropIndexQuery { + q.concurrently = true + return q +} + +func (q *DropIndexQuery) IfExists() *DropIndexQuery { + q.ifExists = true + return q +} + +func (q *DropIndexQuery) Restrict() *DropIndexQuery { + q.restrict = true + return q +} + +func (q *DropIndexQuery) Index(query string, args ...interface{}) *DropIndexQuery { + q.index = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropIndexQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + + b = append(b, "DROP INDEX "...) + + if q.concurrently { + b = append(b, "CONCURRENTLY "...) + } + if q.ifExists { + b = append(b, "IF EXISTS "...) + } + + b, err = q.index.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + b = q.appendCascade(fmter, b) + + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *DropIndexQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err := q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + return res, nil +} diff --git a/vendor/github.com/uptrace/bun/query_insert.go b/vendor/github.com/uptrace/bun/query_insert.go new file mode 100644 index 000000000..efddee407 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_insert.go @@ -0,0 +1,551 @@ +package bun + +import ( + "context" + "database/sql" + "fmt" + "reflect" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type InsertQuery struct { + whereBaseQuery + returningQuery + customValueQuery + + onConflict schema.QueryWithArgs + setQuery + + ignore bool + replace bool +} + +func NewInsertQuery(db *DB) *InsertQuery { + q := &InsertQuery{ + whereBaseQuery: whereBaseQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + }, + } + return q +} + +func (q *InsertQuery) Conn(db IConn) *InsertQuery { + q.setConn(db) + return q +} + +func (q *InsertQuery) Model(model interface{}) *InsertQuery { + q.setTableModel(model) + return q +} + +// Apply calls the fn passing the SelectQuery as an argument. +func (q *InsertQuery) Apply(fn func(*InsertQuery) *InsertQuery) *InsertQuery { + return fn(q) +} + +func (q *InsertQuery) With(name string, query schema.QueryAppender) *InsertQuery { + q.addWith(name, query) + return q +} + +//------------------------------------------------------------------------------ + +func (q *InsertQuery) Table(tables ...string) *InsertQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *InsertQuery) TableExpr(query string, args ...interface{}) *InsertQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *InsertQuery) ModelTableExpr(query string, args ...interface{}) *InsertQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *InsertQuery) Column(columns ...string) *InsertQuery { + for _, column := range columns { + q.addColumn(schema.UnsafeIdent(column)) + } + return q +} + +func (q *InsertQuery) ExcludeColumn(columns ...string) *InsertQuery { + q.excludeColumn(columns) + return q +} + +// Value overwrites model value for the column in INSERT and UPDATE queries. +func (q *InsertQuery) Value(column string, value string, args ...interface{}) *InsertQuery { + if q.table == nil { + q.err = errNilModel + return q + } + q.addValue(q.table, column, value, args) + return q +} + +func (q *InsertQuery) Where(query string, args ...interface{}) *InsertQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " AND ")) + return q +} + +func (q *InsertQuery) WhereOr(query string, args ...interface{}) *InsertQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " OR ")) + return q +} + +//------------------------------------------------------------------------------ + +// Returning adds a RETURNING clause to the query. +// +// To suppress the auto-generated RETURNING clause, use `Returning("NULL")`. +func (q *InsertQuery) Returning(query string, args ...interface{}) *InsertQuery { + q.addReturning(schema.SafeQuery(query, args)) + return q +} + +func (q *InsertQuery) hasReturning() bool { + if !q.db.features.Has(feature.Returning) { + return false + } + return q.returningQuery.hasReturning() +} + +//------------------------------------------------------------------------------ + +// Ignore generates an `INSERT IGNORE INTO` query (MySQL). +func (q *InsertQuery) Ignore() *InsertQuery { + q.ignore = true + return q +} + +// Replaces generates a `REPLACE INTO` query (MySQL). +func (q *InsertQuery) Replace() *InsertQuery { + q.replace = true + return q +} + +//------------------------------------------------------------------------------ + +func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + fmter = formatterWithModel(fmter, q) + + b, err = q.appendWith(fmter, b) + if err != nil { + return nil, err + } + + if q.replace { + b = append(b, "REPLACE "...) + } else { + b = append(b, "INSERT "...) + if q.ignore { + b = append(b, "IGNORE "...) + } + } + b = append(b, "INTO "...) + + if q.db.features.Has(feature.InsertTableAlias) && !q.onConflict.IsZero() { + b, err = q.appendFirstTableWithAlias(fmter, b) + } else { + b, err = q.appendFirstTable(fmter, b) + } + if err != nil { + return nil, err + } + + b, err = q.appendColumnsValues(fmter, b) + if err != nil { + return nil, err + } + + b, err = q.appendOn(fmter, b) + if err != nil { + return nil, err + } + + if q.hasReturning() { + b, err = q.appendReturning(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *InsertQuery) appendColumnsValues( + fmter schema.Formatter, b []byte, +) (_ []byte, err error) { + if q.hasMultiTables() { + if q.columns != nil { + b = append(b, " ("...) + b, err = q.appendColumns(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ")"...) + } + + b = append(b, " SELECT * FROM "...) + b, err = q.appendOtherTables(fmter, b) + if err != nil { + return nil, err + } + + return b, nil + } + + if m, ok := q.model.(*mapModel); ok { + return m.appendColumnsValues(fmter, b), nil + } + if _, ok := q.model.(*mapSliceModel); ok { + return nil, fmt.Errorf("Insert(*[]map[string]interface{}) is not supported") + } + + if q.model == nil { + return nil, errNilModel + } + + fields, err := q.getFields() + if err != nil { + return nil, err + } + + b = append(b, " ("...) + b = q.appendFields(fmter, b, fields) + b = append(b, ") VALUES ("...) + + switch model := q.tableModel.(type) { + case *structTableModel: + b, err = q.appendStructValues(fmter, b, fields, model.strct) + if err != nil { + return nil, err + } + case *sliceTableModel: + b, err = q.appendSliceValues(fmter, b, fields, model.slice) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("bun: Insert does not support %T", q.tableModel) + } + + b = append(b, ')') + + return b, nil +} + +func (q *InsertQuery) appendStructValues( + fmter schema.Formatter, b []byte, fields []*schema.Field, strct reflect.Value, +) (_ []byte, err error) { + isTemplate := fmter.IsNop() + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + + app, ok := q.modelValues[f.Name] + if ok { + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + q.addReturningField(f) + continue + } + + switch { + case isTemplate: + b = append(b, '?') + case f.NullZero && f.HasZeroValue(strct): + if q.db.features.Has(feature.DefaultPlaceholder) { + b = append(b, "DEFAULT"...) + } else if f.SQLDefault != "" { + b = append(b, f.SQLDefault...) + } else { + b = append(b, "NULL"...) + } + q.addReturningField(f) + default: + b = f.AppendValue(fmter, b, strct) + } + } + + for i, v := range q.extraValues { + if i > 0 || len(fields) > 0 { + b = append(b, ", "...) + } + + b, err = v.value.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *InsertQuery) appendSliceValues( + fmter schema.Formatter, b []byte, fields []*schema.Field, slice reflect.Value, +) (_ []byte, err error) { + if fmter.IsNop() { + return q.appendStructValues(fmter, b, fields, reflect.Value{}) + } + + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + b = append(b, "), ("...) + } + el := indirect(slice.Index(i)) + b, err = q.appendStructValues(fmter, b, fields, el) + if err != nil { + return nil, err + } + } + + for i, v := range q.extraValues { + if i > 0 || len(fields) > 0 { + b = append(b, ", "...) + } + + b, err = v.value.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *InsertQuery) getFields() ([]*schema.Field, error) { + if q.db.features.Has(feature.DefaultPlaceholder) || len(q.columns) > 0 { + return q.baseQuery.getFields() + } + + var strct reflect.Value + + switch model := q.tableModel.(type) { + case *structTableModel: + strct = model.strct + case *sliceTableModel: + if model.sliceLen == 0 { + return nil, fmt.Errorf("bun: Insert(empty %T)", model.slice.Type()) + } + strct = indirect(model.slice.Index(0)) + } + + fields := make([]*schema.Field, 0, len(q.table.Fields)) + + for _, f := range q.table.Fields { + if f.NotNull && f.NullZero && f.SQLDefault == "" && f.HasZeroValue(strct) { + q.addReturningField(f) + continue + } + fields = append(fields, f) + } + + return fields, nil +} + +func (q *InsertQuery) appendFields( + fmter schema.Formatter, b []byte, fields []*schema.Field, +) []byte { + b = appendColumns(b, "", fields) + for i, v := range q.extraValues { + if i > 0 || len(fields) > 0 { + b = append(b, ", "...) + } + b = fmter.AppendIdent(b, v.column) + } + return b +} + +//------------------------------------------------------------------------------ + +func (q *InsertQuery) On(s string, args ...interface{}) *InsertQuery { + q.onConflict = schema.SafeQuery(s, args) + return q +} + +func (q *InsertQuery) Set(query string, args ...interface{}) *InsertQuery { + q.addSet(schema.SafeQuery(query, args)) + return q +} + +func (q *InsertQuery) appendOn(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.onConflict.IsZero() { + return b, nil + } + + b = append(b, " ON "...) + b, err = q.onConflict.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + if len(q.set) > 0 { + if fmter.HasFeature(feature.OnDuplicateKey) { + b = append(b, ' ') + } else { + b = append(b, " SET "...) + } + + b, err = q.appendSet(fmter, b) + if err != nil { + return nil, err + } + } else if len(q.columns) > 0 { + fields, err := q.getDataFields() + if err != nil { + return nil, err + } + + if len(fields) == 0 { + fields = q.tableModel.Table().DataFields + } + + b = q.appendSetExcluded(b, fields) + } + + b, err = q.appendWhere(fmter, b, true) + if err != nil { + return nil, err + } + + return b, nil +} + +func (q *InsertQuery) appendSetExcluded(b []byte, fields []*schema.Field) []byte { + b = append(b, " SET "...) + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + b = append(b, f.SQLName...) + b = append(b, " = EXCLUDED."...) + b = append(b, f.SQLName...) + } + return b +} + +//------------------------------------------------------------------------------ + +func (q *InsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + if q.table != nil { + if err := q.beforeInsertHook(ctx); err != nil { + return nil, err + } + } + + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + var res sql.Result + + if hasDest := len(dest) > 0; hasDest || q.hasReturning() { + model, err := q.getModel(dest) + if err != nil { + return nil, err + } + + res, err = q.scan(ctx, q, query, model, hasDest) + if err != nil { + return nil, err + } + } else { + res, err = q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + if err := q.tryLastInsertID(res, dest); err != nil { + return nil, err + } + } + + if q.table != nil { + if err := q.afterInsertHook(ctx); err != nil { + return nil, err + } + } + + return res, nil +} + +func (q *InsertQuery) beforeInsertHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(BeforeInsertHook); ok { + if err := hook.BeforeInsert(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *InsertQuery) afterInsertHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(AfterInsertHook); ok { + if err := hook.AfterInsert(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *InsertQuery) tryLastInsertID(res sql.Result, dest []interface{}) error { + if q.db.features.Has(feature.Returning) || q.table == nil || len(q.table.PKs) != 1 { + return nil + } + + id, err := res.LastInsertId() + if err != nil { + return err + } + if id == 0 { + return nil + } + + model, err := q.getModel(dest) + if err != nil { + return err + } + + pk := q.table.PKs[0] + switch model := model.(type) { + case *structTableModel: + if err := pk.ScanValue(model.strct, id); err != nil { + return err + } + case *sliceTableModel: + sliceLen := model.slice.Len() + for i := 0; i < sliceLen; i++ { + strct := indirect(model.slice.Index(i)) + if err := pk.ScanValue(strct, id); err != nil { + return err + } + id++ + } + } + + return nil +} diff --git a/vendor/github.com/uptrace/bun/query_select.go b/vendor/github.com/uptrace/bun/query_select.go new file mode 100644 index 000000000..1f63686ad --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_select.go @@ -0,0 +1,830 @@ +package bun + +import ( + "bytes" + "context" + "database/sql" + "errors" + "fmt" + "strconv" + "strings" + "sync" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type union struct { + expr string + query *SelectQuery +} + +type SelectQuery struct { + whereBaseQuery + + distinctOn []schema.QueryWithArgs + joins []joinQuery + group []schema.QueryWithArgs + having []schema.QueryWithArgs + order []schema.QueryWithArgs + limit int32 + offset int32 + selFor schema.QueryWithArgs + + union []union +} + +func NewSelectQuery(db *DB) *SelectQuery { + return &SelectQuery{ + whereBaseQuery: whereBaseQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + }, + } +} + +func (q *SelectQuery) Conn(db IConn) *SelectQuery { + q.setConn(db) + return q +} + +func (q *SelectQuery) Model(model interface{}) *SelectQuery { + q.setTableModel(model) + return q +} + +// Apply calls the fn passing the SelectQuery as an argument. +func (q *SelectQuery) Apply(fn func(*SelectQuery) *SelectQuery) *SelectQuery { + return fn(q) +} + +func (q *SelectQuery) With(name string, query schema.QueryAppender) *SelectQuery { + q.addWith(name, query) + return q +} + +func (q *SelectQuery) Distinct() *SelectQuery { + q.distinctOn = make([]schema.QueryWithArgs, 0) + return q +} + +func (q *SelectQuery) DistinctOn(query string, args ...interface{}) *SelectQuery { + q.distinctOn = append(q.distinctOn, schema.SafeQuery(query, args)) + return q +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) Table(tables ...string) *SelectQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *SelectQuery) TableExpr(query string, args ...interface{}) *SelectQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *SelectQuery) ModelTableExpr(query string, args ...interface{}) *SelectQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) Column(columns ...string) *SelectQuery { + for _, column := range columns { + q.addColumn(schema.UnsafeIdent(column)) + } + return q +} + +func (q *SelectQuery) ColumnExpr(query string, args ...interface{}) *SelectQuery { + q.addColumn(schema.SafeQuery(query, args)) + return q +} + +func (q *SelectQuery) ExcludeColumn(columns ...string) *SelectQuery { + q.excludeColumn(columns) + return q +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) WherePK() *SelectQuery { + q.flags = q.flags.Set(wherePKFlag) + return q +} + +func (q *SelectQuery) Where(query string, args ...interface{}) *SelectQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " AND ")) + return q +} + +func (q *SelectQuery) WhereOr(query string, args ...interface{}) *SelectQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " OR ")) + return q +} + +func (q *SelectQuery) WhereGroup(sep string, fn func(*SelectQuery) *SelectQuery) *SelectQuery { + saved := q.where + q.where = nil + + q = fn(q) + + where := q.where + q.where = saved + + q.addWhereGroup(sep, where) + + return q +} + +func (q *SelectQuery) WhereDeleted() *SelectQuery { + q.whereDeleted() + return q +} + +func (q *SelectQuery) WhereAllWithDeleted() *SelectQuery { + q.whereAllWithDeleted() + return q +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) Group(columns ...string) *SelectQuery { + for _, column := range columns { + q.group = append(q.group, schema.UnsafeIdent(column)) + } + return q +} + +func (q *SelectQuery) GroupExpr(group string, args ...interface{}) *SelectQuery { + q.group = append(q.group, schema.SafeQuery(group, args)) + return q +} + +func (q *SelectQuery) Having(having string, args ...interface{}) *SelectQuery { + q.having = append(q.having, schema.SafeQuery(having, args)) + return q +} + +func (q *SelectQuery) Order(orders ...string) *SelectQuery { + for _, order := range orders { + if order == "" { + continue + } + + index := strings.IndexByte(order, ' ') + if index == -1 { + q.order = append(q.order, schema.UnsafeIdent(order)) + continue + } + + field := order[:index] + sort := order[index+1:] + + switch strings.ToUpper(sort) { + case "ASC", "DESC", "ASC NULLS FIRST", "DESC NULLS FIRST", + "ASC NULLS LAST", "DESC NULLS LAST": + q.order = append(q.order, schema.SafeQuery("? ?", []interface{}{ + Ident(field), + Safe(sort), + })) + default: + q.order = append(q.order, schema.UnsafeIdent(order)) + } + } + return q +} + +func (q *SelectQuery) OrderExpr(query string, args ...interface{}) *SelectQuery { + q.order = append(q.order, schema.SafeQuery(query, args)) + return q +} + +func (q *SelectQuery) Limit(n int) *SelectQuery { + q.limit = int32(n) + return q +} + +func (q *SelectQuery) Offset(n int) *SelectQuery { + q.offset = int32(n) + return q +} + +func (q *SelectQuery) For(s string, args ...interface{}) *SelectQuery { + q.selFor = schema.SafeQuery(s, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) Union(other *SelectQuery) *SelectQuery { + return q.addUnion(" UNION ", other) +} + +func (q *SelectQuery) UnionAll(other *SelectQuery) *SelectQuery { + return q.addUnion(" UNION ALL ", other) +} + +func (q *SelectQuery) Intersect(other *SelectQuery) *SelectQuery { + return q.addUnion(" INTERSECT ", other) +} + +func (q *SelectQuery) IntersectAll(other *SelectQuery) *SelectQuery { + return q.addUnion(" INTERSECT ALL ", other) +} + +func (q *SelectQuery) Except(other *SelectQuery) *SelectQuery { + return q.addUnion(" EXCEPT ", other) +} + +func (q *SelectQuery) ExceptAll(other *SelectQuery) *SelectQuery { + return q.addUnion(" EXCEPT ALL ", other) +} + +func (q *SelectQuery) addUnion(expr string, other *SelectQuery) *SelectQuery { + q.union = append(q.union, union{ + expr: expr, + query: other, + }) + return q +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) Join(join string, args ...interface{}) *SelectQuery { + q.joins = append(q.joins, joinQuery{ + join: schema.SafeQuery(join, args), + }) + return q +} + +func (q *SelectQuery) JoinOn(cond string, args ...interface{}) *SelectQuery { + return q.joinOn(cond, args, " AND ") +} + +func (q *SelectQuery) JoinOnOr(cond string, args ...interface{}) *SelectQuery { + return q.joinOn(cond, args, " OR ") +} + +func (q *SelectQuery) joinOn(cond string, args []interface{}, sep string) *SelectQuery { + if len(q.joins) == 0 { + q.err = errors.New("bun: query has no joins") + return q + } + j := &q.joins[len(q.joins)-1] + j.on = append(j.on, schema.SafeQueryWithSep(cond, args, sep)) + return q +} + +//------------------------------------------------------------------------------ + +// Relation adds a relation to the query. Relation name can be: +// - RelationName to select all columns, +// - RelationName.column_name, +// - RelationName._ to join relation without selecting relation columns. +func (q *SelectQuery) Relation(name string, apply ...func(*SelectQuery) *SelectQuery) *SelectQuery { + if q.tableModel == nil { + q.setErr(errNilModel) + return q + } + + var fn func(*SelectQuery) *SelectQuery + + if len(apply) == 1 { + fn = apply[0] + } else if len(apply) > 1 { + panic("only one apply function is supported") + } + + join := q.tableModel.Join(name, fn) + if join == nil { + q.setErr(fmt.Errorf("%s does not have relation=%q", q.table, name)) + return q + } + + return q +} + +func (q *SelectQuery) forEachHasOneJoin(fn func(*join) error) error { + if q.tableModel == nil { + return nil + } + return q._forEachHasOneJoin(fn, q.tableModel.GetJoins()) +} + +func (q *SelectQuery) _forEachHasOneJoin(fn func(*join) error, joins []join) error { + for i := range joins { + j := &joins[i] + switch j.Relation.Type { + case schema.HasOneRelation, schema.BelongsToRelation: + if err := fn(j); err != nil { + return err + } + if err := q._forEachHasOneJoin(fn, j.JoinModel.GetJoins()); err != nil { + return err + } + } + } + return nil +} + +func (q *SelectQuery) selectJoins(ctx context.Context, joins []join) error { + var err error + for i := range joins { + j := &joins[i] + switch j.Relation.Type { + case schema.HasOneRelation, schema.BelongsToRelation: + err = q.selectJoins(ctx, j.JoinModel.GetJoins()) + default: + err = j.Select(ctx, q.db.NewSelect()) + } + if err != nil { + return err + } + } + return nil +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + return q.appendQuery(fmter, b, false) +} + +func (q *SelectQuery) appendQuery( + fmter schema.Formatter, b []byte, count bool, +) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + fmter = formatterWithModel(fmter, q) + + cteCount := count && (len(q.group) > 0 || q.distinctOn != nil) + if cteCount { + b = append(b, "WITH _count_wrapper AS ("...) + } + + if len(q.union) > 0 { + b = append(b, '(') + } + + b, err = q.appendWith(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, "SELECT "...) + + if len(q.distinctOn) > 0 { + b = append(b, "DISTINCT ON ("...) + for i, app := range q.distinctOn { + if i > 0 { + b = append(b, ", "...) + } + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + b = append(b, ") "...) + } else if q.distinctOn != nil { + b = append(b, "DISTINCT "...) + } + + if count && !cteCount { + b = append(b, "count(*)"...) + } else { + b, err = q.appendColumns(fmter, b) + if err != nil { + return nil, err + } + } + + if q.hasTables() { + b, err = q.appendTables(fmter, b) + if err != nil { + return nil, err + } + } + + if err := q.forEachHasOneJoin(func(j *join) error { + b = append(b, ' ') + b, err = j.appendHasOneJoin(fmter, b, q) + return err + }); err != nil { + return nil, err + } + + for _, j := range q.joins { + b, err = j.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + b, err = q.appendWhere(fmter, b, true) + if err != nil { + return nil, err + } + + if len(q.group) > 0 { + b = append(b, " GROUP BY "...) + for i, f := range q.group { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + } + + if len(q.having) > 0 { + b = append(b, " HAVING "...) + for i, f := range q.having { + if i > 0 { + b = append(b, " AND "...) + } + b = append(b, '(') + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ')') + } + } + + if !count { + b, err = q.appendOrder(fmter, b) + if err != nil { + return nil, err + } + + if q.limit != 0 { + b = append(b, " LIMIT "...) + b = strconv.AppendInt(b, int64(q.limit), 10) + } + + if q.offset != 0 { + b = append(b, " OFFSET "...) + b = strconv.AppendInt(b, int64(q.offset), 10) + } + + if !q.selFor.IsZero() { + b = append(b, " FOR "...) + b, err = q.selFor.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + } + + if len(q.union) > 0 { + b = append(b, ')') + + for _, u := range q.union { + b = append(b, u.expr...) + b = append(b, '(') + b, err = u.query.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ')') + } + } + + if cteCount { + b = append(b, ") SELECT count(*) FROM _count_wrapper"...) + } + + return b, nil +} + +func (q *SelectQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) { + start := len(b) + + switch { + case q.columns != nil: + for i, col := range q.columns { + if i > 0 { + b = append(b, ", "...) + } + + if col.Args == nil { + if field, ok := q.table.FieldMap[col.Query]; ok { + b = append(b, q.table.SQLAlias...) + b = append(b, '.') + b = append(b, field.SQLName...) + continue + } + } + + b, err = col.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + case q.table != nil: + if len(q.table.Fields) > 10 && fmter.IsNop() { + b = append(b, q.table.SQLAlias...) + b = append(b, '.') + b = dialect.AppendString(b, fmt.Sprintf("%d columns", len(q.table.Fields))) + } else { + b = appendColumns(b, q.table.SQLAlias, q.table.Fields) + } + default: + b = append(b, '*') + } + + if err := q.forEachHasOneJoin(func(j *join) error { + if len(b) != start { + b = append(b, ", "...) + start = len(b) + } + + b, err = q.appendHasOneColumns(fmter, b, j) + if err != nil { + return err + } + + return nil + }); err != nil { + return nil, err + } + + b = bytes.TrimSuffix(b, []byte(", ")) + + return b, nil +} + +func (q *SelectQuery) appendHasOneColumns( + fmter schema.Formatter, b []byte, join *join, +) (_ []byte, err error) { + join.applyQuery(q) + + if join.columns != nil { + for i, col := range join.columns { + if i > 0 { + b = append(b, ", "...) + } + + if col.Args == nil { + if field, ok := q.table.FieldMap[col.Query]; ok { + b = join.appendAlias(fmter, b) + b = append(b, '.') + b = append(b, field.SQLName...) + b = append(b, " AS "...) + b = join.appendAliasColumn(fmter, b, field.Name) + continue + } + } + + b, err = col.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil + } + + for i, field := range join.JoinModel.Table().Fields { + if i > 0 { + b = append(b, ", "...) + } + b = join.appendAlias(fmter, b) + b = append(b, '.') + b = append(b, field.SQLName...) + b = append(b, " AS "...) + b = join.appendAliasColumn(fmter, b, field.Name) + } + return b, nil +} + +func (q *SelectQuery) appendTables(fmter schema.Formatter, b []byte) (_ []byte, err error) { + b = append(b, " FROM "...) + return q.appendTablesWithAlias(fmter, b) +} + +func (q *SelectQuery) appendOrder(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if len(q.order) > 0 { + b = append(b, " ORDER BY "...) + + for i, f := range q.order { + if i > 0 { + b = append(b, ", "...) + } + b, err = f.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil + } + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *SelectQuery) Rows(ctx context.Context) (*sql.Rows, error) { + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + return q.conn.QueryContext(ctx, query) +} + +func (q *SelectQuery) Exec(ctx context.Context) (res sql.Result, err error) { + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err = q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + return res, nil +} + +func (q *SelectQuery) Scan(ctx context.Context, dest ...interface{}) error { + model, err := q.getModel(dest) + if err != nil { + return err + } + + if q.limit > 1 { + if model, ok := model.(interface{ SetCap(int) }); ok { + model.SetCap(int(q.limit)) + } + } + + if q.table != nil { + if err := q.beforeSelectHook(ctx); err != nil { + return err + } + } + + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return err + } + + query := internal.String(queryBytes) + + res, err := q.scan(ctx, q, query, model, true) + if err != nil { + return err + } + + if res.n > 0 { + if tableModel, ok := model.(tableModel); ok { + if err := q.selectJoins(ctx, tableModel.GetJoins()); err != nil { + return err + } + } + } + + if q.table != nil { + if err := q.afterSelectHook(ctx); err != nil { + return err + } + } + + return nil +} + +func (q *SelectQuery) beforeSelectHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(BeforeSelectHook); ok { + if err := hook.BeforeSelect(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *SelectQuery) afterSelectHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(AfterSelectHook); ok { + if err := hook.AfterSelect(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *SelectQuery) Count(ctx context.Context) (int, error) { + qq := countQuery{q} + + queryBytes, err := qq.appendQuery(q.db.fmter, nil, true) + if err != nil { + return 0, err + } + + query := internal.String(queryBytes) + ctx, event := q.db.beforeQuery(ctx, qq, query, nil) + + var num int + err = q.conn.QueryRowContext(ctx, query).Scan(&num) + + q.db.afterQuery(ctx, event, nil, err) + + return num, err +} + +func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (int, error) { + var count int + var wg sync.WaitGroup + var mu sync.Mutex + var firstErr error + + if q.limit >= 0 { + wg.Add(1) + go func() { + defer wg.Done() + + if err := q.Scan(ctx, dest...); err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + }() + } + + wg.Add(1) + go func() { + defer wg.Done() + + var err error + count, err = q.Count(ctx) + if err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + }() + + wg.Wait() + return count, firstErr +} + +//------------------------------------------------------------------------------ + +type joinQuery struct { + join schema.QueryWithArgs + on []schema.QueryWithSep +} + +func (j *joinQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + b = append(b, ' ') + + b, err = j.join.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + + if len(j.on) > 0 { + b = append(b, " ON "...) + for i, on := range j.on { + if i > 0 { + b = append(b, on.Sep...) + } + + b = append(b, '(') + b, err = on.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ')') + } + } + + return b, nil +} + +//------------------------------------------------------------------------------ + +type countQuery struct { + *SelectQuery +} + +func (q countQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + return q.appendQuery(fmter, b, true) +} diff --git a/vendor/github.com/uptrace/bun/query_table_create.go b/vendor/github.com/uptrace/bun/query_table_create.go new file mode 100644 index 000000000..0a4b3567c --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_table_create.go @@ -0,0 +1,275 @@ +package bun + +import ( + "context" + "database/sql" + "sort" + "strconv" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/dialect/sqltype" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type CreateTableQuery struct { + baseQuery + + temp bool + ifNotExists bool + varchar int + + fks []schema.QueryWithArgs + partitionBy schema.QueryWithArgs + tablespace schema.QueryWithArgs +} + +func NewCreateTableQuery(db *DB) *CreateTableQuery { + q := &CreateTableQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + } + return q +} + +func (q *CreateTableQuery) Conn(db IConn) *CreateTableQuery { + q.setConn(db) + return q +} + +func (q *CreateTableQuery) Model(model interface{}) *CreateTableQuery { + q.setTableModel(model) + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateTableQuery) Table(tables ...string) *CreateTableQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *CreateTableQuery) TableExpr(query string, args ...interface{}) *CreateTableQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *CreateTableQuery) ModelTableExpr(query string, args ...interface{}) *CreateTableQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *CreateTableQuery) Temp() *CreateTableQuery { + q.temp = true + return q +} + +func (q *CreateTableQuery) IfNotExists() *CreateTableQuery { + q.ifNotExists = true + return q +} + +func (q *CreateTableQuery) Varchar(n int) *CreateTableQuery { + q.varchar = n + return q +} + +func (q *CreateTableQuery) ForeignKey(query string, args ...interface{}) *CreateTableQuery { + q.fks = append(q.fks, schema.SafeQuery(query, args)) + return q +} + +func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + if q.table == nil { + return nil, errNilModel + } + + b = append(b, "CREATE "...) + if q.temp { + b = append(b, "TEMP "...) + } + b = append(b, "TABLE "...) + if q.ifNotExists { + b = append(b, "IF NOT EXISTS "...) + } + b, err = q.appendFirstTable(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, " ("...) + + for i, field := range q.table.Fields { + if i > 0 { + b = append(b, ", "...) + } + + b = append(b, field.SQLName...) + b = append(b, " "...) + b = q.appendSQLType(b, field) + if field.NotNull { + b = append(b, " NOT NULL"...) + } + if q.db.features.Has(feature.AutoIncrement) && field.AutoIncrement { + b = append(b, " AUTO_INCREMENT"...) + } + if field.SQLDefault != "" { + b = append(b, " DEFAULT "...) + b = append(b, field.SQLDefault...) + } + } + + b = q.appendPKConstraint(b, q.table.PKs) + b = q.appendUniqueConstraints(fmter, b) + b, err = q.appenFKConstraints(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, ")"...) + + if !q.partitionBy.IsZero() { + b = append(b, " PARTITION BY "...) + b, err = q.partitionBy.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + if !q.tablespace.IsZero() { + b = append(b, " TABLESPACE "...) + b, err = q.tablespace.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *CreateTableQuery) appendSQLType(b []byte, field *schema.Field) []byte { + if field.CreateTableSQLType != field.DiscoveredSQLType { + return append(b, field.CreateTableSQLType...) + } + + if q.varchar > 0 && + field.CreateTableSQLType == sqltype.VarChar { + b = append(b, "varchar("...) + b = strconv.AppendInt(b, int64(q.varchar), 10) + b = append(b, ")"...) + return b + } + + return append(b, field.CreateTableSQLType...) +} + +func (q *CreateTableQuery) appendUniqueConstraints(fmter schema.Formatter, b []byte) []byte { + unique := q.table.Unique + + keys := make([]string, 0, len(unique)) + for key := range unique { + keys = append(keys, key) + } + sort.Strings(keys) + + for _, key := range keys { + b = q.appendUniqueConstraint(fmter, b, key, unique[key]) + } + + return b +} + +func (q *CreateTableQuery) appendUniqueConstraint( + fmter schema.Formatter, b []byte, name string, fields []*schema.Field, +) []byte { + if name != "" { + b = append(b, ", CONSTRAINT "...) + b = fmter.AppendIdent(b, name) + } else { + b = append(b, ","...) + } + b = append(b, " UNIQUE ("...) + b = appendColumns(b, "", fields) + b = append(b, ")"...) + + return b +} + +func (q *CreateTableQuery) appenFKConstraints( + fmter schema.Formatter, b []byte, +) (_ []byte, err error) { + for _, fk := range q.fks { + b = append(b, ", FOREIGN KEY "...) + b, err = fk.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + return b, nil +} + +func (q *CreateTableQuery) appendPKConstraint(b []byte, pks []*schema.Field) []byte { + if len(pks) == 0 { + return b + } + + b = append(b, ", PRIMARY KEY ("...) + b = appendColumns(b, "", pks) + b = append(b, ")"...) + return b +} + +//------------------------------------------------------------------------------ + +func (q *CreateTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + if err := q.beforeCreateTableHook(ctx); err != nil { + return nil, err + } + + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err := q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + if q.table != nil { + if err := q.afterCreateTableHook(ctx); err != nil { + return nil, err + } + } + + return res, nil +} + +func (q *CreateTableQuery) beforeCreateTableHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(BeforeCreateTableHook); ok { + if err := hook.BeforeCreateTable(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *CreateTableQuery) afterCreateTableHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(AfterCreateTableHook); ok { + if err := hook.AfterCreateTable(ctx, q); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/uptrace/bun/query_table_drop.go b/vendor/github.com/uptrace/bun/query_table_drop.go new file mode 100644 index 000000000..2c30171c1 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_table_drop.go @@ -0,0 +1,137 @@ +package bun + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type DropTableQuery struct { + baseQuery + cascadeQuery + + ifExists bool +} + +func NewDropTableQuery(db *DB) *DropTableQuery { + q := &DropTableQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + } + return q +} + +func (q *DropTableQuery) Conn(db IConn) *DropTableQuery { + q.setConn(db) + return q +} + +func (q *DropTableQuery) Model(model interface{}) *DropTableQuery { + q.setTableModel(model) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropTableQuery) Table(tables ...string) *DropTableQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *DropTableQuery) TableExpr(query string, args ...interface{}) *DropTableQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *DropTableQuery) ModelTableExpr(query string, args ...interface{}) *DropTableQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropTableQuery) IfExists() *DropTableQuery { + q.ifExists = true + return q +} + +func (q *DropTableQuery) Restrict() *DropTableQuery { + q.restrict = true + return q +} + +//------------------------------------------------------------------------------ + +func (q *DropTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + + b = append(b, "DROP TABLE "...) + if q.ifExists { + b = append(b, "IF EXISTS "...) + } + + b, err = q.appendTables(fmter, b) + if err != nil { + return nil, err + } + + b = q.appendCascade(fmter, b) + + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *DropTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + if q.table != nil { + if err := q.beforeDropTableHook(ctx); err != nil { + return nil, err + } + } + + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err := q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + if q.table != nil { + if err := q.afterDropTableHook(ctx); err != nil { + return nil, err + } + } + + return res, nil +} + +func (q *DropTableQuery) beforeDropTableHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(BeforeDropTableHook); ok { + if err := hook.BeforeDropTable(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *DropTableQuery) afterDropTableHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(AfterDropTableHook); ok { + if err := hook.AfterDropTable(ctx, q); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/uptrace/bun/query_table_truncate.go b/vendor/github.com/uptrace/bun/query_table_truncate.go new file mode 100644 index 000000000..1e4bef7f6 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_table_truncate.go @@ -0,0 +1,121 @@ +package bun + +import ( + "context" + "database/sql" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type TruncateTableQuery struct { + baseQuery + cascadeQuery + + continueIdentity bool +} + +func NewTruncateTableQuery(db *DB) *TruncateTableQuery { + q := &TruncateTableQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + } + return q +} + +func (q *TruncateTableQuery) Conn(db IConn) *TruncateTableQuery { + q.setConn(db) + return q +} + +func (q *TruncateTableQuery) Model(model interface{}) *TruncateTableQuery { + q.setTableModel(model) + return q +} + +//------------------------------------------------------------------------------ + +func (q *TruncateTableQuery) Table(tables ...string) *TruncateTableQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *TruncateTableQuery) TableExpr(query string, args ...interface{}) *TruncateTableQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +//------------------------------------------------------------------------------ + +func (q *TruncateTableQuery) ContinueIdentity() *TruncateTableQuery { + q.continueIdentity = true + return q +} + +func (q *TruncateTableQuery) Restrict() *TruncateTableQuery { + q.restrict = true + return q +} + +//------------------------------------------------------------------------------ + +func (q *TruncateTableQuery) AppendQuery( + fmter schema.Formatter, b []byte, +) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + + if !fmter.HasFeature(feature.TableTruncate) { + b = append(b, "DELETE FROM "...) + + b, err = q.appendTables(fmter, b) + if err != nil { + return nil, err + } + + return b, nil + } + + b = append(b, "TRUNCATE TABLE "...) + + b, err = q.appendTables(fmter, b) + if err != nil { + return nil, err + } + + if q.db.features.Has(feature.TableIdentity) { + if q.continueIdentity { + b = append(b, " CONTINUE IDENTITY"...) + } else { + b = append(b, " RESTART IDENTITY"...) + } + } + + b = q.appendCascade(fmter, b) + + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *TruncateTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + res, err := q.exec(ctx, q, query) + if err != nil { + return nil, err + } + + return res, nil +} diff --git a/vendor/github.com/uptrace/bun/query_update.go b/vendor/github.com/uptrace/bun/query_update.go new file mode 100644 index 000000000..ea74e1419 --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_update.go @@ -0,0 +1,432 @@ +package bun + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/schema" +) + +type UpdateQuery struct { + whereBaseQuery + returningQuery + customValueQuery + setQuery + + omitZero bool +} + +func NewUpdateQuery(db *DB) *UpdateQuery { + q := &UpdateQuery{ + whereBaseQuery: whereBaseQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + }, + } + return q +} + +func (q *UpdateQuery) Conn(db IConn) *UpdateQuery { + q.setConn(db) + return q +} + +func (q *UpdateQuery) Model(model interface{}) *UpdateQuery { + q.setTableModel(model) + return q +} + +// Apply calls the fn passing the SelectQuery as an argument. +func (q *UpdateQuery) Apply(fn func(*UpdateQuery) *UpdateQuery) *UpdateQuery { + return fn(q) +} + +func (q *UpdateQuery) With(name string, query schema.QueryAppender) *UpdateQuery { + q.addWith(name, query) + return q +} + +//------------------------------------------------------------------------------ + +func (q *UpdateQuery) Table(tables ...string) *UpdateQuery { + for _, table := range tables { + q.addTable(schema.UnsafeIdent(table)) + } + return q +} + +func (q *UpdateQuery) TableExpr(query string, args ...interface{}) *UpdateQuery { + q.addTable(schema.SafeQuery(query, args)) + return q +} + +func (q *UpdateQuery) ModelTableExpr(query string, args ...interface{}) *UpdateQuery { + q.modelTable = schema.SafeQuery(query, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *UpdateQuery) Column(columns ...string) *UpdateQuery { + for _, column := range columns { + q.addColumn(schema.UnsafeIdent(column)) + } + return q +} + +func (q *UpdateQuery) ExcludeColumn(columns ...string) *UpdateQuery { + q.excludeColumn(columns) + return q +} + +func (q *UpdateQuery) Set(query string, args ...interface{}) *UpdateQuery { + q.addSet(schema.SafeQuery(query, args)) + return q +} + +// Value overwrites model value for the column in INSERT and UPDATE queries. +func (q *UpdateQuery) Value(column string, value string, args ...interface{}) *UpdateQuery { + if q.table == nil { + q.err = errNilModel + return q + } + q.addValue(q.table, column, value, args) + return q +} + +//------------------------------------------------------------------------------ + +func (q *UpdateQuery) WherePK() *UpdateQuery { + q.flags = q.flags.Set(wherePKFlag) + return q +} + +func (q *UpdateQuery) Where(query string, args ...interface{}) *UpdateQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " AND ")) + return q +} + +func (q *UpdateQuery) WhereOr(query string, args ...interface{}) *UpdateQuery { + q.addWhere(schema.SafeQueryWithSep(query, args, " OR ")) + return q +} + +func (q *UpdateQuery) WhereGroup(sep string, fn func(*UpdateQuery) *UpdateQuery) *UpdateQuery { + saved := q.where + q.where = nil + + q = fn(q) + + where := q.where + q.where = saved + + q.addWhereGroup(sep, where) + + return q +} + +func (q *UpdateQuery) WhereDeleted() *UpdateQuery { + q.whereDeleted() + return q +} + +func (q *UpdateQuery) WhereAllWithDeleted() *UpdateQuery { + q.whereAllWithDeleted() + return q +} + +//------------------------------------------------------------------------------ + +// Returning adds a RETURNING clause to the query. +// +// To suppress the auto-generated RETURNING clause, use `Returning("NULL")`. +func (q *UpdateQuery) Returning(query string, args ...interface{}) *UpdateQuery { + q.addReturning(schema.SafeQuery(query, args)) + return q +} + +func (q *UpdateQuery) hasReturning() bool { + if !q.db.features.Has(feature.Returning) { + return false + } + return q.returningQuery.hasReturning() +} + +//------------------------------------------------------------------------------ + +func (q *UpdateQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + fmter = formatterWithModel(fmter, q) + + withAlias := fmter.HasFeature(feature.UpdateMultiTable) + + b, err = q.appendWith(fmter, b) + if err != nil { + return nil, err + } + + b = append(b, "UPDATE "...) + + if withAlias { + b, err = q.appendTablesWithAlias(fmter, b) + } else { + b, err = q.appendFirstTableWithAlias(fmter, b) + } + if err != nil { + return nil, err + } + + b, err = q.mustAppendSet(fmter, b) + if err != nil { + return nil, err + } + + if !fmter.HasFeature(feature.UpdateMultiTable) { + b, err = q.appendOtherTables(fmter, b) + if err != nil { + return nil, err + } + } + + b, err = q.mustAppendWhere(fmter, b, withAlias) + if err != nil { + return nil, err + } + + if len(q.returning) > 0 { + b, err = q.appendReturning(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *UpdateQuery) mustAppendSet(fmter schema.Formatter, b []byte) (_ []byte, err error) { + b = append(b, " SET "...) + + if len(q.set) > 0 { + return q.appendSet(fmter, b) + } + + if m, ok := q.model.(*mapModel); ok { + return m.appendSet(fmter, b), nil + } + + if q.tableModel == nil { + return nil, errNilModel + } + + switch model := q.tableModel.(type) { + case *structTableModel: + b, err = q.appendSetStruct(fmter, b, model) + if err != nil { + return nil, err + } + case *sliceTableModel: + return nil, errors.New("bun: to bulk Update, use CTE and VALUES") + default: + return nil, fmt.Errorf("bun: Update does not support %T", q.tableModel) + } + + return b, nil +} + +func (q *UpdateQuery) appendSetStruct( + fmter schema.Formatter, b []byte, model *structTableModel, +) ([]byte, error) { + fields, err := q.getDataFields() + if err != nil { + return nil, err + } + + isTemplate := fmter.IsNop() + pos := len(b) + for _, f := range fields { + if q.omitZero && f.NullZero && f.HasZeroValue(model.strct) { + continue + } + + if len(b) != pos { + b = append(b, ", "...) + pos = len(b) + } + + b = append(b, f.SQLName...) + b = append(b, " = "...) + + if isTemplate { + b = append(b, '?') + continue + } + + app, ok := q.modelValues[f.Name] + if ok { + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } else { + b = f.AppendValue(fmter, b, model.strct) + } + } + + for i, v := range q.extraValues { + if i > 0 || len(fields) > 0 { + b = append(b, ", "...) + } + + b = append(b, v.column...) + b = append(b, " = "...) + + b, err = v.value.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + return b, nil +} + +func (q *UpdateQuery) appendOtherTables(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if !q.hasMultiTables() { + return b, nil + } + + b = append(b, " FROM "...) + + b, err = q.whereBaseQuery.appendOtherTables(fmter, b) + if err != nil { + return nil, err + } + + return b, nil +} + +//------------------------------------------------------------------------------ + +func (q *UpdateQuery) Bulk() *UpdateQuery { + model, ok := q.model.(*sliceTableModel) + if !ok { + q.setErr(fmt.Errorf("bun: Bulk requires a slice, got %T", q.model)) + return q + } + + return q.With("_data", q.db.NewValues(model)). + Model(model). + TableExpr("_data"). + Set(q.updateSliceSet(model)). + Where(q.updateSliceWhere(model)) +} + +func (q *UpdateQuery) updateSliceSet(model *sliceTableModel) string { + var b []byte + for i, field := range model.table.DataFields { + if i > 0 { + b = append(b, ", "...) + } + if q.db.fmter.HasFeature(feature.UpdateMultiTable) { + b = append(b, model.table.SQLAlias...) + b = append(b, '.') + } + b = append(b, field.SQLName...) + b = append(b, " = _data."...) + b = append(b, field.SQLName...) + } + return internal.String(b) +} + +func (db *UpdateQuery) updateSliceWhere(model *sliceTableModel) string { + var b []byte + for i, pk := range model.table.PKs { + if i > 0 { + b = append(b, " AND "...) + } + b = append(b, model.table.SQLAlias...) + b = append(b, '.') + b = append(b, pk.SQLName...) + b = append(b, " = _data."...) + b = append(b, pk.SQLName...) + } + return internal.String(b) +} + +//------------------------------------------------------------------------------ + +func (q *UpdateQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + if q.table != nil { + if err := q.beforeUpdateHook(ctx); err != nil { + return nil, err + } + } + + queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) + if err != nil { + return nil, err + } + + query := internal.String(queryBytes) + + var res sql.Result + + if hasDest := len(dest) > 0; hasDest || q.hasReturning() { + model, err := q.getModel(dest) + if err != nil { + return nil, err + } + + res, err = q.scan(ctx, q, query, model, hasDest) + if err != nil { + return nil, err + } + } else { + res, err = q.exec(ctx, q, query) + if err != nil { + return nil, err + } + } + + if q.table != nil { + if err := q.afterUpdateHook(ctx); err != nil { + return nil, err + } + } + + return res, nil +} + +func (q *UpdateQuery) beforeUpdateHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(BeforeUpdateHook); ok { + if err := hook.BeforeUpdate(ctx, q); err != nil { + return err + } + } + return nil +} + +func (q *UpdateQuery) afterUpdateHook(ctx context.Context) error { + if hook, ok := q.table.ZeroIface.(AfterUpdateHook); ok { + if err := hook.AfterUpdate(ctx, q); err != nil { + return err + } + } + return nil +} + +// FQN returns a fully qualified column name. For MySQL, it returns the column name with +// the table alias. For other RDBMS, it returns just the column name. +func (q *UpdateQuery) FQN(name string) Ident { + if q.db.fmter.HasFeature(feature.UpdateMultiTable) { + return Ident(q.table.Alias + "." + name) + } + return Ident(name) +} diff --git a/vendor/github.com/uptrace/bun/query_values.go b/vendor/github.com/uptrace/bun/query_values.go new file mode 100644 index 000000000..323ac68ef --- /dev/null +++ b/vendor/github.com/uptrace/bun/query_values.go @@ -0,0 +1,198 @@ +package bun + +import ( + "fmt" + "reflect" + "strconv" + + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/schema" +) + +type ValuesQuery struct { + baseQuery + customValueQuery + + withOrder bool +} + +var _ schema.NamedArgAppender = (*ValuesQuery)(nil) + +func NewValuesQuery(db *DB, model interface{}) *ValuesQuery { + q := &ValuesQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + } + q.setTableModel(model) + return q +} + +func (q *ValuesQuery) Conn(db IConn) *ValuesQuery { + q.setConn(db) + return q +} + +func (q *ValuesQuery) WithOrder() *ValuesQuery { + q.withOrder = true + return q +} + +func (q *ValuesQuery) AppendNamedArg(fmter schema.Formatter, b []byte, name string) ([]byte, bool) { + switch name { + case "Columns": + bb, err := q.AppendColumns(fmter, b) + if err != nil { + q.setErr(err) + return b, true + } + return bb, true + } + return b, false +} + +// AppendColumns appends the table columns. It is used by CTE. +func (q *ValuesQuery) AppendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + if q.model == nil { + return nil, errNilModel + } + + if q.tableModel != nil { + fields, err := q.getFields() + if err != nil { + return nil, err + } + + b = appendColumns(b, "", fields) + + if q.withOrder { + b = append(b, ", _order"...) + } + + return b, nil + } + + switch model := q.model.(type) { + case *mapSliceModel: + return model.appendColumns(fmter, b) + } + + return nil, fmt.Errorf("bun: Values does not support %T", q.model) +} + +func (q *ValuesQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + if q.err != nil { + return nil, q.err + } + if q.model == nil { + return nil, errNilModel + } + + fmter = formatterWithModel(fmter, q) + + if q.tableModel != nil { + fields, err := q.getFields() + if err != nil { + return nil, err + } + return q.appendQuery(fmter, b, fields) + } + + switch model := q.model.(type) { + case *mapSliceModel: + return model.appendValues(fmter, b) + } + + return nil, fmt.Errorf("bun: Values does not support %T", q.model) +} + +func (q *ValuesQuery) appendQuery( + fmter schema.Formatter, + b []byte, + fields []*schema.Field, +) (_ []byte, err error) { + b = append(b, "VALUES "...) + if q.db.features.Has(feature.ValuesRow) { + b = append(b, "ROW("...) + } else { + b = append(b, '(') + } + + switch model := q.tableModel.(type) { + case *structTableModel: + b, err = q.appendValues(fmter, b, fields, model.strct) + if err != nil { + return nil, err + } + + if q.withOrder { + b = append(b, ", "...) + b = strconv.AppendInt(b, 0, 10) + } + case *sliceTableModel: + slice := model.slice + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + if i > 0 { + b = append(b, "), "...) + if q.db.features.Has(feature.ValuesRow) { + b = append(b, "ROW("...) + } else { + b = append(b, '(') + } + } + + b, err = q.appendValues(fmter, b, fields, slice.Index(i)) + if err != nil { + return nil, err + } + + if q.withOrder { + b = append(b, ", "...) + b = strconv.AppendInt(b, int64(i), 10) + } + } + default: + return nil, fmt.Errorf("bun: Values does not support %T", q.model) + } + + b = append(b, ')') + + return b, nil +} + +func (q *ValuesQuery) appendValues( + fmter schema.Formatter, b []byte, fields []*schema.Field, strct reflect.Value, +) (_ []byte, err error) { + isTemplate := fmter.IsNop() + for i, f := range fields { + if i > 0 { + b = append(b, ", "...) + } + + app, ok := q.modelValues[f.Name] + if ok { + b, err = app.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + continue + } + + if isTemplate { + b = append(b, '?') + } else { + b = f.AppendValue(fmter, b, indirect(strct)) + } + + if fmter.HasFeature(feature.DoubleColonCast) { + b = append(b, "::"...) + b = append(b, f.UserSQLType...) + } + } + return b, nil +} diff --git a/vendor/github.com/uptrace/bun/schema/append.go b/vendor/github.com/uptrace/bun/schema/append.go new file mode 100644 index 000000000..68f7071c8 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/append.go @@ -0,0 +1,93 @@ +package schema + +import ( + "reflect" + "strconv" + "strings" + "time" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/sqltype" + "github.com/uptrace/bun/internal" +) + +func FieldAppender(dialect Dialect, field *Field) AppenderFunc { + if field.Tag.HasOption("msgpack") { + return appendMsgpack + } + + switch strings.ToUpper(field.UserSQLType) { + case sqltype.JSON, sqltype.JSONB: + return AppendJSONValue + } + + return dialect.Appender(field.StructField.Type) +} + +func Append(fmter Formatter, b []byte, v interface{}, custom CustomAppender) []byte { + switch v := v.(type) { + case nil: + return dialect.AppendNull(b) + case bool: + return dialect.AppendBool(b, v) + case int: + return strconv.AppendInt(b, int64(v), 10) + case int32: + return strconv.AppendInt(b, int64(v), 10) + case int64: + return strconv.AppendInt(b, v, 10) + case uint: + return strconv.AppendUint(b, uint64(v), 10) + case uint32: + return strconv.AppendUint(b, uint64(v), 10) + case uint64: + return strconv.AppendUint(b, v, 10) + case float32: + return dialect.AppendFloat32(b, v) + case float64: + return dialect.AppendFloat64(b, v) + case string: + return dialect.AppendString(b, v) + case time.Time: + return dialect.AppendTime(b, v) + case []byte: + return dialect.AppendBytes(b, v) + case QueryAppender: + return AppendQueryAppender(fmter, b, v) + default: + vv := reflect.ValueOf(v) + if vv.Kind() == reflect.Ptr && vv.IsNil() { + return dialect.AppendNull(b) + } + appender := Appender(vv.Type(), custom) + return appender(fmter, b, vv) + } +} + +func appendMsgpack(fmter Formatter, b []byte, v reflect.Value) []byte { + hexEnc := internal.NewHexEncoder(b) + + enc := msgpack.GetEncoder() + defer msgpack.PutEncoder(enc) + + enc.Reset(hexEnc) + if err := enc.EncodeValue(v); err != nil { + return dialect.AppendError(b, err) + } + + if err := hexEnc.Close(); err != nil { + return dialect.AppendError(b, err) + } + + return hexEnc.Bytes() +} + +func AppendQueryAppender(fmter Formatter, b []byte, app QueryAppender) []byte { + bb, err := app.AppendQuery(fmter, b) + if err != nil { + return dialect.AppendError(b, err) + } + return bb +} diff --git a/vendor/github.com/uptrace/bun/schema/append_value.go b/vendor/github.com/uptrace/bun/schema/append_value.go new file mode 100644 index 000000000..0c4677069 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/append_value.go @@ -0,0 +1,237 @@ +package schema + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "net" + "reflect" + "strconv" + "time" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/extra/bunjson" + "github.com/uptrace/bun/internal" +) + +var ( + timeType = reflect.TypeOf((*time.Time)(nil)).Elem() + ipType = reflect.TypeOf((*net.IP)(nil)).Elem() + ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() + jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() + + driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem() +) + +type ( + AppenderFunc func(fmter Formatter, b []byte, v reflect.Value) []byte + CustomAppender func(typ reflect.Type) AppenderFunc +) + +var appenders = []AppenderFunc{ + reflect.Bool: AppendBoolValue, + reflect.Int: AppendIntValue, + reflect.Int8: AppendIntValue, + reflect.Int16: AppendIntValue, + reflect.Int32: AppendIntValue, + reflect.Int64: AppendIntValue, + reflect.Uint: AppendUintValue, + reflect.Uint8: AppendUintValue, + reflect.Uint16: AppendUintValue, + reflect.Uint32: AppendUintValue, + reflect.Uint64: AppendUintValue, + reflect.Uintptr: nil, + reflect.Float32: AppendFloat32Value, + reflect.Float64: AppendFloat64Value, + reflect.Complex64: nil, + reflect.Complex128: nil, + reflect.Array: AppendJSONValue, + reflect.Chan: nil, + reflect.Func: nil, + reflect.Interface: nil, + reflect.Map: AppendJSONValue, + reflect.Ptr: nil, + reflect.Slice: AppendJSONValue, + reflect.String: AppendStringValue, + reflect.Struct: AppendJSONValue, + reflect.UnsafePointer: nil, +} + +func Appender(typ reflect.Type, custom CustomAppender) AppenderFunc { + switch typ { + case timeType: + return appendTimeValue + case ipType: + return appendIPValue + case ipNetType: + return appendIPNetValue + case jsonRawMessageType: + return appendJSONRawMessageValue + } + + if typ.Implements(queryAppenderType) { + return appendQueryAppenderValue + } + if typ.Implements(driverValuerType) { + return driverValueAppender(custom) + } + + kind := typ.Kind() + + if kind != reflect.Ptr { + ptr := reflect.PtrTo(typ) + if ptr.Implements(queryAppenderType) { + return addrAppender(appendQueryAppenderValue, custom) + } + if ptr.Implements(driverValuerType) { + return addrAppender(driverValueAppender(custom), custom) + } + } + + switch kind { + case reflect.Interface: + return ifaceAppenderFunc(typ, custom) + case reflect.Ptr: + return ptrAppenderFunc(typ, custom) + case reflect.Slice: + if typ.Elem().Kind() == reflect.Uint8 { + return appendBytesValue + } + case reflect.Array: + if typ.Elem().Kind() == reflect.Uint8 { + return appendArrayBytesValue + } + } + + if custom != nil { + if fn := custom(typ); fn != nil { + return fn + } + } + return appenders[typ.Kind()] +} + +func ifaceAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc { + return func(fmter Formatter, b []byte, v reflect.Value) []byte { + if v.IsNil() { + return dialect.AppendNull(b) + } + elem := v.Elem() + appender := Appender(elem.Type(), custom) + return appender(fmter, b, elem) + } +} + +func ptrAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc { + appender := Appender(typ.Elem(), custom) + return func(fmter Formatter, b []byte, v reflect.Value) []byte { + if v.IsNil() { + return dialect.AppendNull(b) + } + return appender(fmter, b, v.Elem()) + } +} + +func AppendBoolValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return dialect.AppendBool(b, v.Bool()) +} + +func AppendIntValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return strconv.AppendInt(b, v.Int(), 10) +} + +func AppendUintValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return strconv.AppendUint(b, v.Uint(), 10) +} + +func AppendFloat32Value(fmter Formatter, b []byte, v reflect.Value) []byte { + return dialect.AppendFloat32(b, float32(v.Float())) +} + +func AppendFloat64Value(fmter Formatter, b []byte, v reflect.Value) []byte { + return dialect.AppendFloat64(b, float64(v.Float())) +} + +func appendBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return dialect.AppendBytes(b, v.Bytes()) +} + +func appendArrayBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte { + if v.CanAddr() { + return dialect.AppendBytes(b, v.Slice(0, v.Len()).Bytes()) + } + + tmp := make([]byte, v.Len()) + reflect.Copy(reflect.ValueOf(tmp), v) + b = dialect.AppendBytes(b, tmp) + return b +} + +func AppendStringValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return dialect.AppendString(b, v.String()) +} + +func AppendJSONValue(fmter Formatter, b []byte, v reflect.Value) []byte { + bb, err := bunjson.Marshal(v.Interface()) + if err != nil { + return dialect.AppendError(b, err) + } + + if len(bb) > 0 && bb[len(bb)-1] == '\n' { + bb = bb[:len(bb)-1] + } + + return dialect.AppendJSON(b, bb) +} + +func appendTimeValue(fmter Formatter, b []byte, v reflect.Value) []byte { + tm := v.Interface().(time.Time) + return dialect.AppendTime(b, tm) +} + +func appendIPValue(fmter Formatter, b []byte, v reflect.Value) []byte { + ip := v.Interface().(net.IP) + return dialect.AppendString(b, ip.String()) +} + +func appendIPNetValue(fmter Formatter, b []byte, v reflect.Value) []byte { + ipnet := v.Interface().(net.IPNet) + return dialect.AppendString(b, ipnet.String()) +} + +func appendJSONRawMessageValue(fmter Formatter, b []byte, v reflect.Value) []byte { + bytes := v.Bytes() + if bytes == nil { + return dialect.AppendNull(b) + } + return dialect.AppendString(b, internal.String(bytes)) +} + +func appendQueryAppenderValue(fmter Formatter, b []byte, v reflect.Value) []byte { + return AppendQueryAppender(fmter, b, v.Interface().(QueryAppender)) +} + +func driverValueAppender(custom CustomAppender) AppenderFunc { + return func(fmter Formatter, b []byte, v reflect.Value) []byte { + return appendDriverValue(fmter, b, v.Interface().(driver.Valuer), custom) + } +} + +func appendDriverValue(fmter Formatter, b []byte, v driver.Valuer, custom CustomAppender) []byte { + value, err := v.Value() + if err != nil { + return dialect.AppendError(b, err) + } + return Append(fmter, b, value, custom) +} + +func addrAppender(fn AppenderFunc, custom CustomAppender) AppenderFunc { + return func(fmter Formatter, b []byte, v reflect.Value) []byte { + if !v.CanAddr() { + err := fmt.Errorf("bun: Append(nonaddressable %T)", v.Interface()) + return dialect.AppendError(b, err) + } + return fn(fmter, b, v.Addr()) + } +} diff --git a/vendor/github.com/uptrace/bun/schema/dialect.go b/vendor/github.com/uptrace/bun/schema/dialect.go new file mode 100644 index 000000000..c50de715a --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/dialect.go @@ -0,0 +1,99 @@ +package schema + +import ( + "database/sql" + "reflect" + "sync" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/feature" +) + +type Dialect interface { + Init(db *sql.DB) + + Name() dialect.Name + Features() feature.Feature + + Tables() *Tables + OnTable(table *Table) + + IdentQuote() byte + Append(fmter Formatter, b []byte, v interface{}) []byte + Appender(typ reflect.Type) AppenderFunc + FieldAppender(field *Field) AppenderFunc + Scanner(typ reflect.Type) ScannerFunc +} + +//------------------------------------------------------------------------------ + +type nopDialect struct { + tables *Tables + features feature.Feature + + appenderMap sync.Map + scannerMap sync.Map +} + +func newNopDialect() *nopDialect { + d := new(nopDialect) + d.tables = NewTables(d) + d.features = feature.Returning + return d +} + +func (d *nopDialect) Init(*sql.DB) {} + +func (d *nopDialect) Name() dialect.Name { + return dialect.Invalid +} + +func (d *nopDialect) Features() feature.Feature { + return d.features +} + +func (d *nopDialect) Tables() *Tables { + return d.tables +} + +func (d *nopDialect) OnField(field *Field) {} + +func (d *nopDialect) OnTable(table *Table) {} + +func (d *nopDialect) IdentQuote() byte { + return '"' +} + +func (d *nopDialect) Append(fmter Formatter, b []byte, v interface{}) []byte { + return Append(fmter, b, v, nil) +} + +func (d *nopDialect) Appender(typ reflect.Type) AppenderFunc { + if v, ok := d.appenderMap.Load(typ); ok { + return v.(AppenderFunc) + } + + fn := Appender(typ, nil) + + if v, ok := d.appenderMap.LoadOrStore(typ, fn); ok { + return v.(AppenderFunc) + } + return fn +} + +func (d *nopDialect) FieldAppender(field *Field) AppenderFunc { + return FieldAppender(d, field) +} + +func (d *nopDialect) Scanner(typ reflect.Type) ScannerFunc { + if v, ok := d.scannerMap.Load(typ); ok { + return v.(ScannerFunc) + } + + fn := Scanner(typ) + + if v, ok := d.scannerMap.LoadOrStore(typ, fn); ok { + return v.(ScannerFunc) + } + return fn +} diff --git a/vendor/github.com/uptrace/bun/schema/field.go b/vendor/github.com/uptrace/bun/schema/field.go new file mode 100644 index 000000000..1e069b82f --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/field.go @@ -0,0 +1,117 @@ +package schema + +import ( + "fmt" + "reflect" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/internal/tagparser" +) + +type Field struct { + StructField reflect.StructField + + Tag tagparser.Tag + IndirectType reflect.Type + Index []int + + Name string // SQL name, .e.g. id + SQLName Safe // escaped SQL name, e.g. "id" + GoName string // struct field name, e.g. Id + + DiscoveredSQLType string + UserSQLType string + CreateTableSQLType string + SQLDefault string + + OnDelete string + OnUpdate string + + IsPK bool + NotNull bool + NullZero bool + AutoIncrement bool + + Append AppenderFunc + Scan ScannerFunc + IsZero IsZeroerFunc +} + +func (f *Field) String() string { + return f.Name +} + +func (f *Field) Clone() *Field { + cp := *f + cp.Index = cp.Index[:len(f.Index):len(f.Index)] + return &cp +} + +func (f *Field) Value(strct reflect.Value) reflect.Value { + return fieldByIndexAlloc(strct, f.Index) +} + +func (f *Field) HasZeroValue(v reflect.Value) bool { + for _, idx := range f.Index { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return true + } + v = v.Elem() + } + v = v.Field(idx) + } + return f.IsZero(v) +} + +func (f *Field) AppendValue(fmter Formatter, b []byte, strct reflect.Value) []byte { + fv, ok := fieldByIndex(strct, f.Index) + if !ok { + return dialect.AppendNull(b) + } + + if f.NullZero && f.IsZero(fv) { + return dialect.AppendNull(b) + } + if f.Append == nil { + panic(fmt.Errorf("bun: AppendValue(unsupported %s)", fv.Type())) + } + return f.Append(fmter, b, fv) +} + +func (f *Field) ScanWithCheck(fv reflect.Value, src interface{}) error { + if f.Scan == nil { + return fmt.Errorf("bun: Scan(unsupported %s)", f.IndirectType) + } + return f.Scan(fv, src) +} + +func (f *Field) ScanValue(strct reflect.Value, src interface{}) error { + if src == nil { + if fv, ok := fieldByIndex(strct, f.Index); ok { + return f.ScanWithCheck(fv, src) + } + return nil + } + + fv := fieldByIndexAlloc(strct, f.Index) + return f.ScanWithCheck(fv, src) +} + +func (f *Field) markAsPK() { + f.IsPK = true + f.NotNull = true + f.NullZero = true +} + +func indexEqual(ind1, ind2 []int) bool { + if len(ind1) != len(ind2) { + return false + } + for i, ind := range ind1 { + if ind != ind2[i] { + return false + } + } + return true +} diff --git a/vendor/github.com/uptrace/bun/schema/formatter.go b/vendor/github.com/uptrace/bun/schema/formatter.go new file mode 100644 index 000000000..7b26fbaca --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/formatter.go @@ -0,0 +1,248 @@ +package schema + +import ( + "reflect" + "strconv" + "strings" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/internal/parser" +) + +var nopFormatter = Formatter{ + dialect: newNopDialect(), +} + +type Formatter struct { + dialect Dialect + args *namedArgList +} + +func NewFormatter(dialect Dialect) Formatter { + return Formatter{ + dialect: dialect, + } +} + +func NewNopFormatter() Formatter { + return nopFormatter +} + +func (f Formatter) IsNop() bool { + return f.dialect.Name() == dialect.Invalid +} + +func (f Formatter) Dialect() Dialect { + return f.dialect +} + +func (f Formatter) IdentQuote() byte { + return f.dialect.IdentQuote() +} + +func (f Formatter) AppendIdent(b []byte, ident string) []byte { + return dialect.AppendIdent(b, ident, f.IdentQuote()) +} + +func (f Formatter) AppendValue(b []byte, v reflect.Value) []byte { + if v.Kind() == reflect.Ptr && v.IsNil() { + return dialect.AppendNull(b) + } + appender := f.dialect.Appender(v.Type()) + return appender(f, b, v) +} + +func (f Formatter) HasFeature(feature feature.Feature) bool { + return f.dialect.Features().Has(feature) +} + +func (f Formatter) WithArg(arg NamedArgAppender) Formatter { + return Formatter{ + dialect: f.dialect, + args: f.args.WithArg(arg), + } +} + +func (f Formatter) WithNamedArg(name string, value interface{}) Formatter { + return Formatter{ + dialect: f.dialect, + args: f.args.WithArg(&namedArg{name: name, value: value}), + } +} + +func (f Formatter) FormatQuery(query string, args ...interface{}) string { + if f.IsNop() || (args == nil && f.args == nil) || strings.IndexByte(query, '?') == -1 { + return query + } + return internal.String(f.AppendQuery(nil, query, args...)) +} + +func (f Formatter) AppendQuery(dst []byte, query string, args ...interface{}) []byte { + if f.IsNop() || (args == nil && f.args == nil) || strings.IndexByte(query, '?') == -1 { + return append(dst, query...) + } + return f.append(dst, parser.NewString(query), args) +} + +func (f Formatter) append(dst []byte, p *parser.Parser, args []interface{}) []byte { + var namedArgs NamedArgAppender + if len(args) == 1 { + var ok bool + namedArgs, ok = args[0].(NamedArgAppender) + if !ok { + namedArgs, _ = newStructArgs(f, args[0]) + } + } + + var argIndex int + for p.Valid() { + b, ok := p.ReadSep('?') + if !ok { + dst = append(dst, b...) + continue + } + if len(b) > 0 && b[len(b)-1] == '\\' { + dst = append(dst, b[:len(b)-1]...) + dst = append(dst, '?') + continue + } + dst = append(dst, b...) + + name, numeric := p.ReadIdentifier() + if name != "" { + if numeric { + idx, err := strconv.Atoi(name) + if err != nil { + goto restore_arg + } + + if idx >= len(args) { + goto restore_arg + } + + dst = f.appendArg(dst, args[idx]) + continue + } + + if namedArgs != nil { + dst, ok = namedArgs.AppendNamedArg(f, dst, name) + if ok { + continue + } + } + + dst, ok = f.args.AppendNamedArg(f, dst, name) + if ok { + continue + } + + restore_arg: + dst = append(dst, '?') + dst = append(dst, name...) + continue + } + + if argIndex >= len(args) { + dst = append(dst, '?') + continue + } + + arg := args[argIndex] + argIndex++ + + dst = f.appendArg(dst, arg) + } + + return dst +} + +func (f Formatter) appendArg(b []byte, arg interface{}) []byte { + switch arg := arg.(type) { + case QueryAppender: + bb, err := arg.AppendQuery(f, b) + if err != nil { + return dialect.AppendError(b, err) + } + return bb + default: + return f.dialect.Append(f, b, arg) + } +} + +//------------------------------------------------------------------------------ + +type NamedArgAppender interface { + AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) +} + +//------------------------------------------------------------------------------ + +type namedArgList struct { + arg NamedArgAppender + next *namedArgList +} + +func (l *namedArgList) WithArg(arg NamedArgAppender) *namedArgList { + return &namedArgList{ + arg: arg, + next: l, + } +} + +func (l *namedArgList) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) { + for l != nil && l.arg != nil { + if b, ok := l.arg.AppendNamedArg(fmter, b, name); ok { + return b, true + } + l = l.next + } + return b, false +} + +//------------------------------------------------------------------------------ + +type namedArg struct { + name string + value interface{} +} + +var _ NamedArgAppender = (*namedArg)(nil) + +func (a *namedArg) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) { + if a.name == name { + return fmter.appendArg(b, a.value), true + } + return b, false +} + +//------------------------------------------------------------------------------ + +var _ NamedArgAppender = (*structArgs)(nil) + +type structArgs struct { + table *Table + strct reflect.Value +} + +func newStructArgs(fmter Formatter, strct interface{}) (*structArgs, bool) { + v := reflect.ValueOf(strct) + if !v.IsValid() { + return nil, false + } + + v = reflect.Indirect(v) + if v.Kind() != reflect.Struct { + return nil, false + } + + return &structArgs{ + table: fmter.Dialect().Tables().Get(v.Type()), + strct: v, + }, true +} + +func (m *structArgs) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) { + return m.table.AppendNamedArg(fmter, b, name, m.strct) +} diff --git a/vendor/github.com/uptrace/bun/schema/hook.go b/vendor/github.com/uptrace/bun/schema/hook.go new file mode 100644 index 000000000..5391981d5 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/hook.go @@ -0,0 +1,20 @@ +package schema + +import ( + "context" + "reflect" +) + +type BeforeScanHook interface { + BeforeScan(context.Context) error +} + +var beforeScanHookType = reflect.TypeOf((*BeforeScanHook)(nil)).Elem() + +//------------------------------------------------------------------------------ + +type AfterScanHook interface { + AfterScan(context.Context) error +} + +var afterScanHookType = reflect.TypeOf((*AfterScanHook)(nil)).Elem() diff --git a/vendor/github.com/uptrace/bun/schema/relation.go b/vendor/github.com/uptrace/bun/schema/relation.go new file mode 100644 index 000000000..8d1baeb3f --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/relation.go @@ -0,0 +1,32 @@ +package schema + +import ( + "fmt" +) + +const ( + InvalidRelation = iota + HasOneRelation + BelongsToRelation + HasManyRelation + ManyToManyRelation +) + +type Relation struct { + Type int + Field *Field + JoinTable *Table + BaseFields []*Field + JoinFields []*Field + + PolymorphicField *Field + PolymorphicValue string + + M2MTable *Table + M2MBaseFields []*Field + M2MJoinFields []*Field +} + +func (r *Relation) String() string { + return fmt.Sprintf("relation=%s", r.Field.GoName) +} diff --git a/vendor/github.com/uptrace/bun/schema/scan.go b/vendor/github.com/uptrace/bun/schema/scan.go new file mode 100644 index 000000000..0e66a860f --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/scan.go @@ -0,0 +1,392 @@ +package schema + +import ( + "bytes" + "database/sql" + "fmt" + "net" + "reflect" + "strconv" + "time" + + "github.com/vmihailenco/msgpack/v5" + + "github.com/uptrace/bun/extra/bunjson" + "github.com/uptrace/bun/internal" +) + +var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + +type ScannerFunc func(dest reflect.Value, src interface{}) error + +var scanners = []ScannerFunc{ + reflect.Bool: scanBool, + reflect.Int: scanInt64, + reflect.Int8: scanInt64, + reflect.Int16: scanInt64, + reflect.Int32: scanInt64, + reflect.Int64: scanInt64, + reflect.Uint: scanUint64, + reflect.Uint8: scanUint64, + reflect.Uint16: scanUint64, + reflect.Uint32: scanUint64, + reflect.Uint64: scanUint64, + reflect.Uintptr: scanUint64, + reflect.Float32: scanFloat64, + reflect.Float64: scanFloat64, + reflect.Complex64: nil, + reflect.Complex128: nil, + reflect.Array: nil, + reflect.Chan: nil, + reflect.Func: nil, + reflect.Map: scanJSON, + reflect.Ptr: nil, + reflect.Slice: scanJSON, + reflect.String: scanString, + reflect.Struct: scanJSON, + reflect.UnsafePointer: nil, +} + +func FieldScanner(dialect Dialect, field *Field) ScannerFunc { + if field.Tag.HasOption("msgpack") { + return scanMsgpack + } + if field.Tag.HasOption("json_use_number") { + return scanJSONUseNumber + } + return dialect.Scanner(field.StructField.Type) +} + +func Scanner(typ reflect.Type) ScannerFunc { + kind := typ.Kind() + + if kind == reflect.Ptr { + if fn := Scanner(typ.Elem()); fn != nil { + return ptrScanner(fn) + } + } + + if typ.Implements(scannerType) { + return scanScanner + } + + if kind != reflect.Ptr { + ptr := reflect.PtrTo(typ) + if ptr.Implements(scannerType) { + return addrScanner(scanScanner) + } + } + + switch typ { + case timeType: + return scanTime + case ipType: + return scanIP + case ipNetType: + return scanIPNet + case jsonRawMessageType: + return scanJSONRawMessage + } + + return scanners[kind] +} + +func scanBool(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetBool(false) + return nil + case bool: + dest.SetBool(src) + return nil + case int64: + dest.SetBool(src != 0) + return nil + case []byte: + if len(src) == 1 { + dest.SetBool(src[0] != '0') + return nil + } + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanInt64(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetInt(0) + return nil + case int64: + dest.SetInt(src) + return nil + case uint64: + dest.SetInt(int64(src)) + return nil + case []byte: + n, err := strconv.ParseInt(internal.String(src), 10, 64) + if err != nil { + return err + } + dest.SetInt(n) + return nil + case string: + n, err := strconv.ParseInt(src, 10, 64) + if err != nil { + return err + } + dest.SetInt(n) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanUint64(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetUint(0) + return nil + case uint64: + dest.SetUint(src) + return nil + case int64: + dest.SetUint(uint64(src)) + return nil + case []byte: + n, err := strconv.ParseUint(internal.String(src), 10, 64) + if err != nil { + return err + } + dest.SetUint(n) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanFloat64(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetFloat(0) + return nil + case float64: + dest.SetFloat(src) + return nil + case []byte: + f, err := strconv.ParseFloat(internal.String(src), 64) + if err != nil { + return err + } + dest.SetFloat(f) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanString(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + dest.SetString("") + return nil + case string: + dest.SetString(src) + return nil + case []byte: + dest.SetString(string(src)) + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanTime(dest reflect.Value, src interface{}) error { + switch src := src.(type) { + case nil: + destTime := dest.Addr().Interface().(*time.Time) + *destTime = time.Time{} + return nil + case time.Time: + destTime := dest.Addr().Interface().(*time.Time) + *destTime = src + return nil + case string: + srcTime, err := internal.ParseTime(src) + if err != nil { + return err + } + destTime := dest.Addr().Interface().(*time.Time) + *destTime = srcTime + return nil + case []byte: + srcTime, err := internal.ParseTime(internal.String(src)) + if err != nil { + return err + } + destTime := dest.Addr().Interface().(*time.Time) + *destTime = srcTime + return nil + } + return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type()) +} + +func scanScanner(dest reflect.Value, src interface{}) error { + return dest.Interface().(sql.Scanner).Scan(src) +} + +func scanMsgpack(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + dec := msgpack.GetDecoder() + defer msgpack.PutDecoder(dec) + + dec.Reset(bytes.NewReader(b)) + return dec.DecodeValue(dest) +} + +func scanJSON(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + return bunjson.Unmarshal(b, dest.Addr().Interface()) +} + +func scanJSONUseNumber(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + dec := bunjson.NewDecoder(bytes.NewReader(b)) + dec.UseNumber() + return dec.Decode(dest.Addr().Interface()) +} + +func scanIP(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + ip := net.ParseIP(internal.String(b)) + if ip == nil { + return fmt.Errorf("bun: invalid ip: %q", b) + } + + ptr := dest.Addr().Interface().(*net.IP) + *ptr = ip + + return nil +} + +func scanIPNet(dest reflect.Value, src interface{}) error { + if src == nil { + return scanNull(dest) + } + + b, err := toBytes(src) + if err != nil { + return err + } + + _, ipnet, err := net.ParseCIDR(internal.String(b)) + if err != nil { + return err + } + + ptr := dest.Addr().Interface().(*net.IPNet) + *ptr = *ipnet + + return nil +} + +func scanJSONRawMessage(dest reflect.Value, src interface{}) error { + if src == nil { + dest.SetBytes(nil) + return nil + } + + b, err := toBytes(src) + if err != nil { + return err + } + + dest.SetBytes(b) + return nil +} + +func addrScanner(fn ScannerFunc) ScannerFunc { + return func(dest reflect.Value, src interface{}) error { + if !dest.CanAddr() { + return fmt.Errorf("bun: Scan(nonaddressable %T)", dest.Interface()) + } + return fn(dest.Addr(), src) + } +} + +func toBytes(src interface{}) ([]byte, error) { + switch src := src.(type) { + case string: + return internal.Bytes(src), nil + case []byte: + return src, nil + default: + return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src) + } +} + +func ptrScanner(fn ScannerFunc) ScannerFunc { + return func(dest reflect.Value, src interface{}) error { + if src == nil { + if !dest.CanAddr() { + if dest.IsNil() { + return nil + } + return fn(dest.Elem(), src) + } + + if !dest.IsNil() { + dest.Set(reflect.New(dest.Type().Elem())) + } + return nil + } + + if dest.IsNil() { + dest.Set(reflect.New(dest.Type().Elem())) + } + return fn(dest.Elem(), src) + } +} + +func scanNull(dest reflect.Value) error { + if nilable(dest.Kind()) && dest.IsNil() { + return nil + } + dest.Set(reflect.New(dest.Type()).Elem()) + return nil +} + +func nilable(kind reflect.Kind) bool { + switch kind { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + return true + } + return false +} diff --git a/vendor/github.com/uptrace/bun/schema/sqlfmt.go b/vendor/github.com/uptrace/bun/schema/sqlfmt.go new file mode 100644 index 000000000..7b538cd0c --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/sqlfmt.go @@ -0,0 +1,76 @@ +package schema + +type QueryAppender interface { + AppendQuery(fmter Formatter, b []byte) ([]byte, error) +} + +type ColumnsAppender interface { + AppendColumns(fmter Formatter, b []byte) ([]byte, error) +} + +//------------------------------------------------------------------------------ + +// Safe represents a safe SQL query. +type Safe string + +var _ QueryAppender = (*Safe)(nil) + +func (s Safe) AppendQuery(fmter Formatter, b []byte) ([]byte, error) { + return append(b, s...), nil +} + +//------------------------------------------------------------------------------ + +// Ident represents a SQL identifier, for example, table or column name. +type Ident string + +var _ QueryAppender = (*Ident)(nil) + +func (s Ident) AppendQuery(fmter Formatter, b []byte) ([]byte, error) { + return fmter.AppendIdent(b, string(s)), nil +} + +//------------------------------------------------------------------------------ + +type QueryWithArgs struct { + Query string + Args []interface{} +} + +var _ QueryAppender = QueryWithArgs{} + +func SafeQuery(query string, args []interface{}) QueryWithArgs { + if query != "" && args == nil { + args = make([]interface{}, 0) + } + return QueryWithArgs{Query: query, Args: args} +} + +func UnsafeIdent(ident string) QueryWithArgs { + return QueryWithArgs{Query: ident} +} + +func (q QueryWithArgs) IsZero() bool { + return q.Query == "" && q.Args == nil +} + +func (q QueryWithArgs) AppendQuery(fmter Formatter, b []byte) ([]byte, error) { + if q.Args == nil { + return fmter.AppendIdent(b, q.Query), nil + } + return fmter.AppendQuery(b, q.Query, q.Args...), nil +} + +//------------------------------------------------------------------------------ + +type QueryWithSep struct { + QueryWithArgs + Sep string +} + +func SafeQueryWithSep(query string, args []interface{}, sep string) QueryWithSep { + return QueryWithSep{ + QueryWithArgs: SafeQuery(query, args), + Sep: sep, + } +} diff --git a/vendor/github.com/uptrace/bun/schema/sqltype.go b/vendor/github.com/uptrace/bun/schema/sqltype.go new file mode 100644 index 000000000..560f695c2 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/sqltype.go @@ -0,0 +1,129 @@ +package schema + +import ( + "bytes" + "database/sql" + "encoding/json" + "fmt" + "reflect" + "time" + + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/sqltype" + "github.com/uptrace/bun/internal" +) + +var ( + bunNullTimeType = reflect.TypeOf((*NullTime)(nil)).Elem() + nullTimeType = reflect.TypeOf((*sql.NullTime)(nil)).Elem() + nullBoolType = reflect.TypeOf((*sql.NullBool)(nil)).Elem() + nullFloatType = reflect.TypeOf((*sql.NullFloat64)(nil)).Elem() + nullIntType = reflect.TypeOf((*sql.NullInt64)(nil)).Elem() + nullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem() +) + +var sqlTypes = []string{ + reflect.Bool: sqltype.Boolean, + reflect.Int: sqltype.BigInt, + reflect.Int8: sqltype.SmallInt, + reflect.Int16: sqltype.SmallInt, + reflect.Int32: sqltype.Integer, + reflect.Int64: sqltype.BigInt, + reflect.Uint: sqltype.BigInt, + reflect.Uint8: sqltype.SmallInt, + reflect.Uint16: sqltype.SmallInt, + reflect.Uint32: sqltype.Integer, + reflect.Uint64: sqltype.BigInt, + reflect.Uintptr: sqltype.BigInt, + reflect.Float32: sqltype.Real, + reflect.Float64: sqltype.DoublePrecision, + reflect.Complex64: "", + reflect.Complex128: "", + reflect.Array: "", + reflect.Chan: "", + reflect.Func: "", + reflect.Interface: "", + reflect.Map: sqltype.VarChar, + reflect.Ptr: "", + reflect.Slice: sqltype.VarChar, + reflect.String: sqltype.VarChar, + reflect.Struct: sqltype.VarChar, + reflect.UnsafePointer: "", +} + +func DiscoverSQLType(typ reflect.Type) string { + switch typ { + case timeType, nullTimeType, bunNullTimeType: + return sqltype.Timestamp + case nullBoolType: + return sqltype.Boolean + case nullFloatType: + return sqltype.DoublePrecision + case nullIntType: + return sqltype.BigInt + case nullStringType: + return sqltype.VarChar + } + return sqlTypes[typ.Kind()] +} + +//------------------------------------------------------------------------------ + +var jsonNull = []byte("null") + +// NullTime is a time.Time wrapper that marshals zero time as JSON null and SQL NULL. +type NullTime struct { + time.Time +} + +var ( + _ json.Marshaler = (*NullTime)(nil) + _ json.Unmarshaler = (*NullTime)(nil) + _ sql.Scanner = (*NullTime)(nil) + _ QueryAppender = (*NullTime)(nil) +) + +func (tm NullTime) MarshalJSON() ([]byte, error) { + if tm.IsZero() { + return jsonNull, nil + } + return tm.Time.MarshalJSON() +} + +func (tm *NullTime) UnmarshalJSON(b []byte) error { + if bytes.Equal(b, jsonNull) { + tm.Time = time.Time{} + return nil + } + return tm.Time.UnmarshalJSON(b) +} + +func (tm NullTime) AppendQuery(fmter Formatter, b []byte) ([]byte, error) { + if tm.IsZero() { + return dialect.AppendNull(b), nil + } + return dialect.AppendTime(b, tm.Time), nil +} + +func (tm *NullTime) Scan(src interface{}) error { + if src == nil { + tm.Time = time.Time{} + return nil + } + + switch src := src.(type) { + case []byte: + newtm, err := internal.ParseTime(internal.String(src)) + if err != nil { + return err + } + + tm.Time = newtm + return nil + case time.Time: + tm.Time = src + return nil + default: + return fmt.Errorf("bun: can't scan %#v into NullTime", src) + } +} diff --git a/vendor/github.com/uptrace/bun/schema/table.go b/vendor/github.com/uptrace/bun/schema/table.go new file mode 100644 index 000000000..eca18b781 --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/table.go @@ -0,0 +1,948 @@ +package schema + +import ( + "database/sql" + "fmt" + "reflect" + "strings" + "sync" + "time" + + "github.com/jinzhu/inflection" + + "github.com/uptrace/bun/internal" + "github.com/uptrace/bun/internal/tagparser" +) + +const ( + beforeScanHookFlag internal.Flag = 1 << iota + afterScanHookFlag +) + +var ( + baseModelType = reflect.TypeOf((*BaseModel)(nil)).Elem() + tableNameInflector = inflection.Plural +) + +type BaseModel struct{} + +// SetTableNameInflector overrides the default func that pluralizes +// model name to get table name, e.g. my_article becomes my_articles. +func SetTableNameInflector(fn func(string) string) { + tableNameInflector = fn +} + +// Table represents a SQL table created from Go struct. +type Table struct { + dialect Dialect + + Type reflect.Type + ZeroValue reflect.Value // reflect.Struct + ZeroIface interface{} // struct pointer + + TypeName string + ModelName string + + Name string + SQLName Safe + SQLNameForSelects Safe + Alias string + SQLAlias Safe + + Fields []*Field // PKs + DataFields + PKs []*Field + DataFields []*Field + + fieldsMapMu sync.RWMutex + FieldMap map[string]*Field + + Relations map[string]*Relation + Unique map[string][]*Field + + SoftDeleteField *Field + UpdateSoftDeleteField func(fv reflect.Value) error + + allFields []*Field // read only + skippedFields []*Field + + flags internal.Flag +} + +func newTable(dialect Dialect, typ reflect.Type) *Table { + t := new(Table) + t.dialect = dialect + t.Type = typ + t.ZeroValue = reflect.New(t.Type).Elem() + t.ZeroIface = reflect.New(t.Type).Interface() + t.TypeName = internal.ToExported(t.Type.Name()) + t.ModelName = internal.Underscore(t.Type.Name()) + tableName := tableNameInflector(t.ModelName) + t.setName(tableName) + t.Alias = t.ModelName + t.SQLAlias = t.quoteIdent(t.ModelName) + + hooks := []struct { + typ reflect.Type + flag internal.Flag + }{ + {beforeScanHookType, beforeScanHookFlag}, + {afterScanHookType, afterScanHookFlag}, + } + + typ = reflect.PtrTo(t.Type) + for _, hook := range hooks { + if typ.Implements(hook.typ) { + t.flags = t.flags.Set(hook.flag) + } + } + + return t +} + +func (t *Table) init1() { + t.initFields() +} + +func (t *Table) init2() { + t.initInlines() + t.initRelations() + t.skippedFields = nil +} + +func (t *Table) setName(name string) { + t.Name = name + t.SQLName = t.quoteIdent(name) + t.SQLNameForSelects = t.quoteIdent(name) + if t.SQLAlias == "" { + t.Alias = name + t.SQLAlias = t.quoteIdent(name) + } +} + +func (t *Table) String() string { + return "model=" + t.TypeName +} + +func (t *Table) CheckPKs() error { + if len(t.PKs) == 0 { + return fmt.Errorf("bun: %s does not have primary keys", t) + } + return nil +} + +func (t *Table) addField(field *Field) { + t.Fields = append(t.Fields, field) + if field.IsPK { + t.PKs = append(t.PKs, field) + } else { + t.DataFields = append(t.DataFields, field) + } + t.FieldMap[field.Name] = field +} + +func (t *Table) removeField(field *Field) { + t.Fields = removeField(t.Fields, field) + if field.IsPK { + t.PKs = removeField(t.PKs, field) + } else { + t.DataFields = removeField(t.DataFields, field) + } + delete(t.FieldMap, field.Name) +} + +func (t *Table) fieldWithLock(name string) *Field { + t.fieldsMapMu.RLock() + field := t.FieldMap[name] + t.fieldsMapMu.RUnlock() + return field +} + +func (t *Table) HasField(name string) bool { + _, ok := t.FieldMap[name] + return ok +} + +func (t *Table) Field(name string) (*Field, error) { + field, ok := t.FieldMap[name] + if !ok { + return nil, fmt.Errorf("bun: %s does not have column=%s", t, name) + } + return field, nil +} + +func (t *Table) fieldByGoName(name string) *Field { + for _, f := range t.allFields { + if f.GoName == name { + return f + } + } + return nil +} + +func (t *Table) initFields() { + t.Fields = make([]*Field, 0, t.Type.NumField()) + t.FieldMap = make(map[string]*Field, t.Type.NumField()) + t.addFields(t.Type, nil) + + if len(t.PKs) > 0 { + return + } + for _, name := range []string{"id", "uuid", "pk_" + t.ModelName} { + if field, ok := t.FieldMap[name]; ok { + field.markAsPK() + t.PKs = []*Field{field} + t.DataFields = removeField(t.DataFields, field) + break + } + } + if len(t.PKs) == 1 { + switch t.PKs[0].IndirectType.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + t.PKs[0].AutoIncrement = true + } + } +} + +func (t *Table) addFields(typ reflect.Type, baseIndex []int) { + for i := 0; i < typ.NumField(); i++ { + f := typ.Field(i) + + // Make a copy so slice is not shared between fields. + index := make([]int, len(baseIndex)) + copy(index, baseIndex) + + if f.Anonymous { + if f.Tag.Get("bun") == "-" { + continue + } + if f.Name == "BaseModel" && f.Type == baseModelType { + if len(index) == 0 { + t.processBaseModelField(f) + } + continue + } + + fieldType := indirectType(f.Type) + if fieldType.Kind() != reflect.Struct { + continue + } + t.addFields(fieldType, append(index, f.Index...)) + + tag := tagparser.Parse(f.Tag.Get("bun")) + if _, inherit := tag.Options["inherit"]; inherit { + embeddedTable := t.dialect.Tables().Ref(fieldType) + t.TypeName = embeddedTable.TypeName + t.SQLName = embeddedTable.SQLName + t.SQLNameForSelects = embeddedTable.SQLNameForSelects + t.Alias = embeddedTable.Alias + t.SQLAlias = embeddedTable.SQLAlias + t.ModelName = embeddedTable.ModelName + } + + continue + } + + field := t.newField(f, index) + if field != nil { + t.addField(field) + } + } +} + +func (t *Table) processBaseModelField(f reflect.StructField) { + tag := tagparser.Parse(f.Tag.Get("bun")) + + if isKnownTableOption(tag.Name) { + internal.Warn.Printf( + "%s.%s tag name %q is also an option name; is it a mistake?", + t.TypeName, f.Name, tag.Name, + ) + } + + for name := range tag.Options { + if !isKnownTableOption(name) { + internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name) + } + } + + if tag.Name != "" { + t.setName(tag.Name) + } + + if s, ok := tag.Options["select"]; ok { + t.SQLNameForSelects = t.quoteTableName(s) + } + + if s, ok := tag.Options["alias"]; ok { + t.Alias = s + t.SQLAlias = t.quoteIdent(s) + } +} + +//nolint +func (t *Table) newField(f reflect.StructField, index []int) *Field { + tag := tagparser.Parse(f.Tag.Get("bun")) + + if f.PkgPath != "" { + return nil + } + + sqlName := internal.Underscore(f.Name) + + if tag.Name != sqlName && isKnownFieldOption(tag.Name) { + internal.Warn.Printf( + "%s.%s tag name %q is also an option name; is it a mistake?", + t.TypeName, f.Name, tag.Name, + ) + } + + for name := range tag.Options { + if !isKnownFieldOption(name) { + internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name) + } + } + + skip := tag.Name == "-" + if !skip && tag.Name != "" { + sqlName = tag.Name + } + + index = append(index, f.Index...) + if field := t.fieldWithLock(sqlName); field != nil { + if indexEqual(field.Index, index) { + return field + } + t.removeField(field) + } + + field := &Field{ + StructField: f, + + Tag: tag, + IndirectType: indirectType(f.Type), + Index: index, + + Name: sqlName, + GoName: f.Name, + SQLName: t.quoteIdent(sqlName), + } + + field.NotNull = tag.HasOption("notnull") + field.NullZero = tag.HasOption("nullzero") + field.AutoIncrement = tag.HasOption("autoincrement") + if tag.HasOption("pk") { + field.markAsPK() + } + if tag.HasOption("allowzero") { + if tag.HasOption("nullzero") { + internal.Warn.Printf( + "%s.%s: nullzero and allowzero options are mutually exclusive", + t.TypeName, f.Name, + ) + } + field.NullZero = false + } + + if v, ok := tag.Options["unique"]; ok { + // Split the value by comma, this will allow multiple names to be specified. + // We can use this to create multiple named unique constraints where a single column + // might be included in multiple constraints. + for _, uniqueName := range strings.Split(v, ",") { + if t.Unique == nil { + t.Unique = make(map[string][]*Field) + } + t.Unique[uniqueName] = append(t.Unique[uniqueName], field) + } + } + if s, ok := tag.Options["default"]; ok { + field.SQLDefault = s + } + if s, ok := field.Tag.Options["type"]; ok { + field.UserSQLType = s + } + field.DiscoveredSQLType = DiscoverSQLType(field.IndirectType) + field.Append = t.dialect.FieldAppender(field) + field.Scan = FieldScanner(t.dialect, field) + field.IsZero = FieldZeroChecker(field) + + if v, ok := tag.Options["alt"]; ok { + t.FieldMap[v] = field + } + + t.allFields = append(t.allFields, field) + if skip { + t.skippedFields = append(t.skippedFields, field) + t.FieldMap[field.Name] = field + return nil + } + + if _, ok := tag.Options["soft_delete"]; ok { + field.NullZero = true + t.SoftDeleteField = field + t.UpdateSoftDeleteField = softDeleteFieldUpdater(field) + } + + return field +} + +func (t *Table) initInlines() { + for _, f := range t.skippedFields { + if f.IndirectType.Kind() == reflect.Struct { + t.inlineFields(f, nil) + } + } +} + +//--------------------------------------------------------------------------------------- + +func (t *Table) initRelations() { + for i := 0; i < len(t.Fields); { + f := t.Fields[i] + if t.tryRelation(f) { + t.Fields = removeField(t.Fields, f) + t.DataFields = removeField(t.DataFields, f) + } else { + i++ + } + + if f.IndirectType.Kind() == reflect.Struct { + t.inlineFields(f, nil) + } + } +} + +func (t *Table) tryRelation(field *Field) bool { + if rel, ok := field.Tag.Options["rel"]; ok { + t.initRelation(field, rel) + return true + } + if field.Tag.HasOption("m2m") { + t.addRelation(t.m2mRelation(field)) + return true + } + + if field.Tag.HasOption("join") { + internal.Warn.Printf( + `%s.%s option "join" requires a relation type`, + t.TypeName, field.GoName, + ) + } + + return false +} + +func (t *Table) initRelation(field *Field, rel string) { + switch rel { + case "belongs-to": + t.addRelation(t.belongsToRelation(field)) + case "has-one": + t.addRelation(t.hasOneRelation(field)) + case "has-many": + t.addRelation(t.hasManyRelation(field)) + default: + panic(fmt.Errorf("bun: unknown relation=%s on field=%s", rel, field.GoName)) + } +} + +func (t *Table) addRelation(rel *Relation) { + if t.Relations == nil { + t.Relations = make(map[string]*Relation) + } + _, ok := t.Relations[rel.Field.GoName] + if ok { + panic(fmt.Errorf("%s already has %s", t, rel)) + } + t.Relations[rel.Field.GoName] = rel +} + +func (t *Table) belongsToRelation(field *Field) *Relation { + joinTable := t.dialect.Tables().Ref(field.IndirectType) + if err := joinTable.CheckPKs(); err != nil { + panic(err) + } + + rel := &Relation{ + Type: HasOneRelation, + Field: field, + JoinTable: joinTable, + } + + if join, ok := field.Tag.Options["join"]; ok { + baseColumns, joinColumns := parseRelationJoin(join) + for i, baseColumn := range baseColumns { + joinColumn := joinColumns[i] + + if f := t.fieldWithLock(baseColumn); f != nil { + rel.BaseFields = append(rel.BaseFields, f) + } else { + panic(fmt.Errorf( + "bun: %s belongs-to %s: %s must have column %s", + t.TypeName, field.GoName, t.TypeName, baseColumn, + )) + } + + if f := joinTable.fieldWithLock(joinColumn); f != nil { + rel.JoinFields = append(rel.JoinFields, f) + } else { + panic(fmt.Errorf( + "bun: %s belongs-to %s: %s must have column %s", + t.TypeName, field.GoName, t.TypeName, baseColumn, + )) + } + } + return rel + } + + rel.JoinFields = joinTable.PKs + fkPrefix := internal.Underscore(field.GoName) + "_" + for _, joinPK := range joinTable.PKs { + fkName := fkPrefix + joinPK.Name + if fk := t.fieldWithLock(fkName); fk != nil { + rel.BaseFields = append(rel.BaseFields, fk) + continue + } + + if fk := t.fieldWithLock(joinPK.Name); fk != nil { + rel.BaseFields = append(rel.BaseFields, fk) + continue + } + + panic(fmt.Errorf( + "bun: %s belongs-to %s: %s must have column %s "+ + "(to override, use join:base_column=join_column tag on %s field)", + t.TypeName, field.GoName, t.TypeName, fkName, field.GoName, + )) + } + return rel +} + +func (t *Table) hasOneRelation(field *Field) *Relation { + if err := t.CheckPKs(); err != nil { + panic(err) + } + + joinTable := t.dialect.Tables().Ref(field.IndirectType) + rel := &Relation{ + Type: BelongsToRelation, + Field: field, + JoinTable: joinTable, + } + + if join, ok := field.Tag.Options["join"]; ok { + baseColumns, joinColumns := parseRelationJoin(join) + for i, baseColumn := range baseColumns { + if f := t.fieldWithLock(baseColumn); f != nil { + rel.BaseFields = append(rel.BaseFields, f) + } else { + panic(fmt.Errorf( + "bun: %s has-one %s: %s must have column %s", + field.GoName, t.TypeName, joinTable.TypeName, baseColumn, + )) + } + + joinColumn := joinColumns[i] + if f := joinTable.fieldWithLock(joinColumn); f != nil { + rel.JoinFields = append(rel.JoinFields, f) + } else { + panic(fmt.Errorf( + "bun: %s has-one %s: %s must have column %s", + field.GoName, t.TypeName, joinTable.TypeName, baseColumn, + )) + } + } + return rel + } + + rel.BaseFields = t.PKs + fkPrefix := internal.Underscore(t.ModelName) + "_" + for _, pk := range t.PKs { + fkName := fkPrefix + pk.Name + if f := joinTable.fieldWithLock(fkName); f != nil { + rel.JoinFields = append(rel.JoinFields, f) + continue + } + + if f := joinTable.fieldWithLock(pk.Name); f != nil { + rel.JoinFields = append(rel.JoinFields, f) + continue + } + + panic(fmt.Errorf( + "bun: %s has-one %s: %s must have column %s "+ + "(to override, use join:base_column=join_column tag on %s field)", + field.GoName, t.TypeName, joinTable.TypeName, fkName, field.GoName, + )) + } + return rel +} + +func (t *Table) hasManyRelation(field *Field) *Relation { + if err := t.CheckPKs(); err != nil { + panic(err) + } + if field.IndirectType.Kind() != reflect.Slice { + panic(fmt.Errorf( + "bun: %s.%s has-many relation requires slice, got %q", + t.TypeName, field.GoName, field.IndirectType.Kind(), + )) + } + + joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem())) + polymorphicValue, isPolymorphic := field.Tag.Options["polymorphic"] + rel := &Relation{ + Type: HasManyRelation, + Field: field, + JoinTable: joinTable, + } + var polymorphicColumn string + + if join, ok := field.Tag.Options["join"]; ok { + baseColumns, joinColumns := parseRelationJoin(join) + for i, baseColumn := range baseColumns { + joinColumn := joinColumns[i] + + if isPolymorphic && baseColumn == "type" { + polymorphicColumn = joinColumn + continue + } + + if f := t.fieldWithLock(baseColumn); f != nil { + rel.BaseFields = append(rel.BaseFields, f) + } else { + panic(fmt.Errorf( + "bun: %s has-one %s: %s must have column %s", + t.TypeName, field.GoName, t.TypeName, baseColumn, + )) + } + + if f := joinTable.fieldWithLock(joinColumn); f != nil { + rel.JoinFields = append(rel.JoinFields, f) + } else { + panic(fmt.Errorf( + "bun: %s has-one %s: %s must have column %s", + t.TypeName, field.GoName, t.TypeName, baseColumn, + )) + } + } + } else { + rel.BaseFields = t.PKs + fkPrefix := internal.Underscore(t.ModelName) + "_" + if isPolymorphic { + polymorphicColumn = fkPrefix + "type" + } + + for _, pk := range t.PKs { + joinColumn := fkPrefix + pk.Name + if fk := joinTable.fieldWithLock(joinColumn); fk != nil { + rel.JoinFields = append(rel.JoinFields, fk) + continue + } + + if fk := joinTable.fieldWithLock(pk.Name); fk != nil { + rel.JoinFields = append(rel.JoinFields, fk) + continue + } + + panic(fmt.Errorf( + "bun: %s has-many %s: %s must have column %s "+ + "(to override, use join:base_column=join_column tag on the field %s)", + t.TypeName, field.GoName, joinTable.TypeName, joinColumn, field.GoName, + )) + } + } + + if isPolymorphic { + rel.PolymorphicField = joinTable.fieldWithLock(polymorphicColumn) + if rel.PolymorphicField == nil { + panic(fmt.Errorf( + "bun: %s has-many %s: %s must have polymorphic column %s", + t.TypeName, field.GoName, joinTable.TypeName, polymorphicColumn, + )) + } + + if polymorphicValue == "" { + polymorphicValue = t.ModelName + } + rel.PolymorphicValue = polymorphicValue + } + + return rel +} + +func (t *Table) m2mRelation(field *Field) *Relation { + if field.IndirectType.Kind() != reflect.Slice { + panic(fmt.Errorf( + "bun: %s.%s m2m relation requires slice, got %q", + t.TypeName, field.GoName, field.IndirectType.Kind(), + )) + } + joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem())) + + if err := t.CheckPKs(); err != nil { + panic(err) + } + if err := joinTable.CheckPKs(); err != nil { + panic(err) + } + + m2mTableName, ok := field.Tag.Options["m2m"] + if !ok { + panic(fmt.Errorf("bun: %s must have m2m tag option", field.GoName)) + } + + m2mTable := t.dialect.Tables().ByName(m2mTableName) + if m2mTable == nil { + panic(fmt.Errorf( + "bun: can't find m2m %s table (use db.RegisterModel)", + m2mTableName, + )) + } + + rel := &Relation{ + Type: ManyToManyRelation, + Field: field, + JoinTable: joinTable, + M2MTable: m2mTable, + } + var leftColumn, rightColumn string + + if join, ok := field.Tag.Options["join"]; ok { + left, right := parseRelationJoin(join) + leftColumn = left[0] + rightColumn = right[0] + } else { + leftColumn = t.TypeName + rightColumn = joinTable.TypeName + } + + leftField := m2mTable.fieldByGoName(leftColumn) + if leftField == nil { + panic(fmt.Errorf( + "bun: %s many-to-many %s: %s must have field %s "+ + "(to override, use tag join:LeftField=RightField on field %s.%s", + t.TypeName, field.GoName, m2mTable.TypeName, leftColumn, t.TypeName, field.GoName, + )) + } + + rightField := m2mTable.fieldByGoName(rightColumn) + if rightField == nil { + panic(fmt.Errorf( + "bun: %s many-to-many %s: %s must have field %s "+ + "(to override, use tag join:LeftField=RightField on field %s.%s", + t.TypeName, field.GoName, m2mTable.TypeName, rightColumn, t.TypeName, field.GoName, + )) + } + + leftRel := m2mTable.belongsToRelation(leftField) + rel.BaseFields = leftRel.JoinFields + rel.M2MBaseFields = leftRel.BaseFields + + rightRel := m2mTable.belongsToRelation(rightField) + rel.JoinFields = rightRel.JoinFields + rel.M2MJoinFields = rightRel.BaseFields + + return rel +} + +func (t *Table) inlineFields(field *Field, path map[reflect.Type]struct{}) { + if path == nil { + path = map[reflect.Type]struct{}{ + t.Type: {}, + } + } + + if _, ok := path[field.IndirectType]; ok { + return + } + path[field.IndirectType] = struct{}{} + + joinTable := t.dialect.Tables().Ref(field.IndirectType) + for _, f := range joinTable.allFields { + f = f.Clone() + f.GoName = field.GoName + "_" + f.GoName + f.Name = field.Name + "__" + f.Name + f.SQLName = t.quoteIdent(f.Name) + f.Index = appendNew(field.Index, f.Index...) + + t.fieldsMapMu.Lock() + if _, ok := t.FieldMap[f.Name]; !ok { + t.FieldMap[f.Name] = f + } + t.fieldsMapMu.Unlock() + + if f.IndirectType.Kind() != reflect.Struct { + continue + } + + if _, ok := path[f.IndirectType]; !ok { + t.inlineFields(f, path) + } + } +} + +//------------------------------------------------------------------------------ + +func (t *Table) Dialect() Dialect { return t.dialect } + +//------------------------------------------------------------------------------ + +func (t *Table) HasBeforeScanHook() bool { return t.flags.Has(beforeScanHookFlag) } +func (t *Table) HasAfterScanHook() bool { return t.flags.Has(afterScanHookFlag) } + +//------------------------------------------------------------------------------ + +func (t *Table) AppendNamedArg( + fmter Formatter, b []byte, name string, strct reflect.Value, +) ([]byte, bool) { + if field, ok := t.FieldMap[name]; ok { + return fmter.appendArg(b, field.Value(strct).Interface()), true + } + return b, false +} + +func (t *Table) quoteTableName(s string) Safe { + // Don't quote if table name contains placeholder (?) or parentheses. + if strings.IndexByte(s, '?') >= 0 || + strings.IndexByte(s, '(') >= 0 || + strings.IndexByte(s, ')') >= 0 { + return Safe(s) + } + return t.quoteIdent(s) +} + +func (t *Table) quoteIdent(s string) Safe { + return Safe(NewFormatter(t.dialect).AppendIdent(nil, s)) +} + +func appendNew(dst []int, src ...int) []int { + cp := make([]int, len(dst)+len(src)) + copy(cp, dst) + copy(cp[len(dst):], src) + return cp +} + +func isKnownTableOption(name string) bool { + switch name { + case "alias", "select": + return true + } + return false +} + +func isKnownFieldOption(name string) bool { + switch name { + case "alias", + "type", + "array", + "hstore", + "composite", + "json_use_number", + "msgpack", + "notnull", + "nullzero", + "allowzero", + "default", + "unique", + "soft_delete", + + "pk", + "autoincrement", + "rel", + "join", + "m2m", + "polymorphic": + return true + } + return false +} + +func removeField(fields []*Field, field *Field) []*Field { + for i, f := range fields { + if f == field { + return append(fields[:i], fields[i+1:]...) + } + } + return fields +} + +func parseRelationJoin(join string) ([]string, []string) { + ss := strings.Split(join, ",") + baseColumns := make([]string, len(ss)) + joinColumns := make([]string, len(ss)) + for i, s := range ss { + ss := strings.Split(strings.TrimSpace(s), "=") + if len(ss) != 2 { + panic(fmt.Errorf("can't parse relation join: %q", join)) + } + baseColumns[i] = ss[0] + joinColumns[i] = ss[1] + } + return baseColumns, joinColumns +} + +//------------------------------------------------------------------------------ + +func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error { + typ := field.StructField.Type + + switch typ { + case timeType: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*time.Time) + *ptr = time.Now() + return nil + } + case nullTimeType: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*sql.NullTime) + *ptr = sql.NullTime{Time: time.Now()} + return nil + } + case nullIntType: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*sql.NullInt64) + *ptr = sql.NullInt64{Int64: time.Now().UnixNano()} + return nil + } + } + + switch field.IndirectType.Kind() { + case reflect.Int64: + return func(fv reflect.Value) error { + ptr := fv.Addr().Interface().(*int64) + *ptr = time.Now().UnixNano() + return nil + } + case reflect.Ptr: + typ = typ.Elem() + default: + return softDeleteFieldUpdaterFallback(field) + } + + switch typ { //nolint:gocritic + case timeType: + return func(fv reflect.Value) error { + now := time.Now() + fv.Set(reflect.ValueOf(&now)) + return nil + } + } + + switch typ.Kind() { //nolint:gocritic + case reflect.Int64: + return func(fv reflect.Value) error { + utime := time.Now().UnixNano() + fv.Set(reflect.ValueOf(&utime)) + return nil + } + } + + return softDeleteFieldUpdaterFallback(field) +} + +func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value) error { + return func(fv reflect.Value) error { + return field.ScanWithCheck(fv, time.Now()) + } +} diff --git a/vendor/github.com/go-pg/pg/v10/orm/tables.go b/vendor/github.com/uptrace/bun/schema/tables.go index fa937a54e..d82d08f59 100644 --- a/vendor/github.com/go-pg/pg/v10/orm/tables.go +++ b/vendor/github.com/uptrace/bun/schema/tables.go @@ -1,15 +1,11 @@ -package orm +package schema import ( "fmt" "reflect" "sync" - - "github.com/go-pg/pg/v10/types" ) -var _tables = newTables() - type tableInProgress struct { table *Table @@ -41,40 +37,36 @@ func (inp *tableInProgress) init2() bool { return inited } -// GetTable returns a Table for a struct type. -func GetTable(typ reflect.Type) *Table { - return _tables.Get(typ) -} - -// RegisterTable registers a struct as SQL table. -// It is usually used to register intermediate table -// in many to many relationship. -func RegisterTable(strct interface{}) { - _tables.Register(strct) -} - -type tables struct { - tables sync.Map +type Tables struct { + dialect Dialect + tables sync.Map mu sync.RWMutex inProgress map[reflect.Type]*tableInProgress } -func newTables() *tables { - return &tables{ +func NewTables(dialect Dialect) *Tables { + return &Tables{ + dialect: dialect, inProgress: make(map[reflect.Type]*tableInProgress), } } -func (t *tables) Register(strct interface{}) { - typ := reflect.TypeOf(strct) - if typ.Kind() == reflect.Ptr { - typ = typ.Elem() +func (t *Tables) Register(models ...interface{}) { + for _, model := range models { + _ = t.Get(reflect.TypeOf(model).Elem()) } - _ = t.Get(typ) } -func (t *tables) get(typ reflect.Type, allowInProgress bool) *Table { +func (t *Tables) Get(typ reflect.Type) *Table { + return t.table(typ, false) +} + +func (t *Tables) Ref(typ reflect.Type) *Table { + return t.table(typ, true) +} + +func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table { if typ.Kind() != reflect.Struct { panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct)) } @@ -94,7 +86,7 @@ func (t *tables) get(typ reflect.Type, allowInProgress bool) *Table { inProgress := t.inProgress[typ] if inProgress == nil { - table = newTable(typ) + table = newTable(t.dialect, typ) inProgress = newTableInProgress(table) t.inProgress[typ] = inProgress } else { @@ -115,18 +107,38 @@ func (t *tables) get(typ reflect.Type, allowInProgress bool) *Table { t.mu.Unlock() } + t.dialect.OnTable(table) + + for _, field := range table.FieldMap { + if field.UserSQLType == "" { + field.UserSQLType = field.DiscoveredSQLType + } + if field.CreateTableSQLType == "" { + field.CreateTableSQLType = field.UserSQLType + } + } + return table } -func (t *tables) Get(typ reflect.Type) *Table { - return t.get(typ, false) +func (t *Tables) ByModel(name string) *Table { + var found *Table + t.tables.Range(func(key, value interface{}) bool { + t := value.(*Table) + if t.TypeName == name { + found = t + return false + } + return true + }) + return found } -func (t *tables) getByName(name types.Safe) *Table { +func (t *Tables) ByName(name string) *Table { var found *Table t.tables.Range(func(key, value interface{}) bool { t := value.(*Table) - if t.SQLName == name { + if t.Name == name { found = t return false } diff --git a/vendor/github.com/uptrace/bun/schema/util.go b/vendor/github.com/uptrace/bun/schema/util.go new file mode 100644 index 000000000..6d474e4cc --- /dev/null +++ b/vendor/github.com/uptrace/bun/schema/util.go @@ -0,0 +1,53 @@ +package schema + +import "reflect" + +func indirectType(t reflect.Type) reflect.Type { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t +} + +func fieldByIndex(v reflect.Value, index []int) (_ reflect.Value, ok bool) { + if len(index) == 1 { + return v.Field(index[0]), true + } + + for i, idx := range index { + if i > 0 { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return v, false + } + v = v.Elem() + } + } + v = v.Field(idx) + } + return v, true +} + +func fieldByIndexAlloc(v reflect.Value, index []int) reflect.Value { + if len(index) == 1 { + return v.Field(index[0]) + } + + for i, idx := range index { + if i > 0 { + v = indirectNil(v) + } + v = v.Field(idx) + } + return v +} + +func indirectNil(v reflect.Value) reflect.Value { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + return v +} diff --git a/vendor/github.com/go-pg/zerochecker/zerochecker.go b/vendor/github.com/uptrace/bun/schema/zerochecker.go index 61bd207c9..95efeee6b 100644 --- a/vendor/github.com/go-pg/zerochecker/zerochecker.go +++ b/vendor/github.com/uptrace/bun/schema/zerochecker.go @@ -1,30 +1,37 @@ -package zerochecker +package schema import ( "database/sql/driver" "reflect" ) -var driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() -var appenderType = reflect.TypeOf((*valueAppender)(nil)).Elem() var isZeroerType = reflect.TypeOf((*isZeroer)(nil)).Elem() type isZeroer interface { IsZero() bool } -type valueAppender interface { - AppendValue(b []byte, flags int) ([]byte, error) -} +type IsZeroerFunc func(reflect.Value) bool -type Func func(reflect.Value) bool +func FieldZeroChecker(field *Field) IsZeroerFunc { + return zeroChecker(field.IndirectType) +} -func Checker(typ reflect.Type) Func { +func zeroChecker(typ reflect.Type) IsZeroerFunc { if typ.Implements(isZeroerType) { return isZeroInterface } - switch typ.Kind() { + kind := typ.Kind() + + if kind != reflect.Ptr { + ptr := reflect.PtrTo(typ) + if ptr.Implements(isZeroerType) { + return addrChecker(isZeroInterface) + } + } + + switch kind { case reflect.Array: if typ.Elem().Kind() == reflect.Uint8 { return isZeroBytes @@ -44,14 +51,20 @@ func Checker(typ reflect.Type) Func { return isNil } - if typ.Implements(appenderType) { - return isZeroAppenderValue - } if typ.Implements(driverValuerType) { return isZeroDriverValue } - return isZeroFalse + return notZero +} + +func addrChecker(fn IsZeroerFunc) IsZeroerFunc { + return func(v reflect.Value) bool { + if !v.CanAddr() { + return false + } + return fn(v.Addr()) + } } func isZeroInterface(v reflect.Value) bool { @@ -61,19 +74,6 @@ func isZeroInterface(v reflect.Value) bool { return v.Interface().(isZeroer).IsZero() } -func isZeroAppenderValue(v reflect.Value) bool { - if v.Kind() == reflect.Ptr { - return v.IsNil() - } - - appender := v.Interface().(valueAppender) - value, err := appender.AppendValue(nil, 0) - if err != nil { - return false - } - return value == nil -} - func isZeroDriverValue(v reflect.Value) bool { if v.Kind() == reflect.Ptr { return v.IsNil() @@ -121,6 +121,6 @@ func isZeroBytes(v reflect.Value) bool { return true } -func isZeroFalse(v reflect.Value) bool { +func notZero(v reflect.Value) bool { return false } diff --git a/vendor/github.com/uptrace/bun/util.go b/vendor/github.com/uptrace/bun/util.go new file mode 100644 index 000000000..ce56be805 --- /dev/null +++ b/vendor/github.com/uptrace/bun/util.go @@ -0,0 +1,114 @@ +package bun + +import "reflect" + +func indirect(v reflect.Value) reflect.Value { + switch v.Kind() { + case reflect.Interface: + return indirect(v.Elem()) + case reflect.Ptr: + return v.Elem() + default: + return v + } +} + +func walk(v reflect.Value, index []int, fn func(reflect.Value)) { + v = reflect.Indirect(v) + switch v.Kind() { + case reflect.Slice: + sliceLen := v.Len() + for i := 0; i < sliceLen; i++ { + visitField(v.Index(i), index, fn) + } + default: + visitField(v, index, fn) + } +} + +func visitField(v reflect.Value, index []int, fn func(reflect.Value)) { + v = reflect.Indirect(v) + if len(index) > 0 { + v = v.Field(index[0]) + if v.Kind() == reflect.Ptr && v.IsNil() { + return + } + walk(v, index[1:], fn) + } else { + fn(v) + } +} + +func typeByIndex(t reflect.Type, index []int) reflect.Type { + for _, x := range index { + switch t.Kind() { + case reflect.Ptr: + t = t.Elem() + case reflect.Slice: + t = indirectType(t.Elem()) + } + t = t.Field(x).Type + } + return indirectType(t) +} + +func indirectType(t reflect.Type) reflect.Type { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t +} + +func sliceElemType(v reflect.Value) reflect.Type { + elemType := v.Type().Elem() + if elemType.Kind() == reflect.Interface && v.Len() > 0 { + return indirect(v.Index(0).Elem()).Type() + } + return indirectType(elemType) +} + +func makeSliceNextElemFunc(v reflect.Value) func() reflect.Value { + if v.Kind() == reflect.Array { + var pos int + return func() reflect.Value { + v := v.Index(pos) + pos++ + return v + } + } + + sliceType := v.Type() + elemType := sliceType.Elem() + + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + return func() reflect.Value { + if v.Len() < v.Cap() { + v.Set(v.Slice(0, v.Len()+1)) + elem := v.Index(v.Len() - 1) + if elem.IsNil() { + elem.Set(reflect.New(elemType)) + } + return elem.Elem() + } + + elem := reflect.New(elemType) + v.Set(reflect.Append(v, elem)) + return elem.Elem() + } + } + + zero := reflect.Zero(elemType) + return func() reflect.Value { + l := v.Len() + c := v.Cap() + + if l < c { + v.Set(v.Slice(0, l+1)) + return v.Index(l) + } + + v.Set(reflect.Append(v, zero)) + return v.Index(l) + } +} diff --git a/vendor/github.com/uptrace/bun/version.go b/vendor/github.com/uptrace/bun/version.go new file mode 100644 index 000000000..1baf9a39c --- /dev/null +++ b/vendor/github.com/uptrace/bun/version.go @@ -0,0 +1,6 @@ +package bun + +// Version is the current release version. +func Version() string { + return "0.4.3" +} diff --git a/vendor/github.com/vmihailenco/bufpool/.travis.yml b/vendor/github.com/vmihailenco/bufpool/.travis.yml deleted file mode 100644 index c7383a2b1..000000000 --- a/vendor/github.com/vmihailenco/bufpool/.travis.yml +++ /dev/null @@ -1,20 +0,0 @@ -sudo: false -language: go - -go: - - 1.11.x - - 1.12.x - - 1.13.x - - tip - -matrix: - allow_failures: - - go: tip - -env: - - GO111MODULE=on - -go_import_path: github.com/vmihailenco/bufpool - -before_install: - - curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b $(go env GOPATH)/bin v1.21.0 diff --git a/vendor/github.com/vmihailenco/bufpool/LICENSE b/vendor/github.com/vmihailenco/bufpool/LICENSE deleted file mode 100644 index 2b76a892e..000000000 --- a/vendor/github.com/vmihailenco/bufpool/LICENSE +++ /dev/null @@ -1,23 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2014 Juan Batiz-Benet -Copyright (c) 2016 Aliaksandr Valialkin, VertaMedia -Copyright (c) 2019 Vladimir Mihailenco - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/vendor/github.com/vmihailenco/bufpool/Makefile b/vendor/github.com/vmihailenco/bufpool/Makefile deleted file mode 100644 index 57914e333..000000000 --- a/vendor/github.com/vmihailenco/bufpool/Makefile +++ /dev/null @@ -1,6 +0,0 @@ -all: - go test ./... - go test ./... -short -race - go test ./... -run=NONE -bench=. -benchmem - env GOOS=linux GOARCH=386 go test ./... - golangci-lint run diff --git a/vendor/github.com/vmihailenco/bufpool/README.md b/vendor/github.com/vmihailenco/bufpool/README.md deleted file mode 100644 index 05a70791c..000000000 --- a/vendor/github.com/vmihailenco/bufpool/README.md +++ /dev/null @@ -1,74 +0,0 @@ -# bufpool - -[](https://travis-ci.org/vmihailenco/bufpool) -[](https://godoc.org/github.com/vmihailenco/bufpool) - -bufpool is an implementation of a pool of byte buffers with anti-memory-waste protection. It is based on the code and ideas from these 2 projects: -- https://github.com/libp2p/go-buffer-pool -- https://github.com/valyala/bytebufferpool - -bufpool consists of global pool of buffers that have a capacity of a power of 2 starting from 64 bytes to 32 megabytes. It also provides individual pools that maintain usage stats to provide buffers of the size that satisfies 95% of the calls. Global pool is used to reuse buffers between different parts of the app. - -# Installation - -``` go -go get github.com/vmihailenco/bufpool -``` - -# Usage - -bufpool can be used as a replacement for `sync.Pool`: - -``` go -var jsonPool bufpool.Pool // basically sync.Pool with usage stats - -func writeJSON(w io.Writer, obj interface{}) error { - buf := jsonPool.Get() - defer jsonPool.Put(buf) - - if err := json.NewEncoder(buf).Encode(obj); err != nil { - return err - } - - _, err := w.Write(buf.Bytes()) - return err -} -``` - -or to allocate buffer of the given size: - -``` go -func writeHex(w io.Writer, data []byte) error { - n := hex.EncodedLen(len(data))) - - buf := bufpool.Get(n) // buf.Len() is guaranteed to equal n - defer bufpool.Put(buf) - - tmp := buf.Bytes() - hex.Encode(tmp, data) - - _, err := w.Write(tmp) - return err -} -``` - -If you need to append data to the buffer you can use following pattern: - -``` go -buf := bufpool.Get(n) -defer bufpool.Put(buf) - -bb := buf.Bytes()[:0] - -bb = append(bb, ...) - -buf.ResetBuf(bb) -``` - -You can also change default pool thresholds: - -``` go -var jsonPool = bufpool.Pool{ - ServePctile: 0.95, // serve p95 buffers -} -``` diff --git a/vendor/github.com/vmihailenco/bufpool/buf_pool.go b/vendor/github.com/vmihailenco/bufpool/buf_pool.go deleted file mode 100644 index 2daa69888..000000000 --- a/vendor/github.com/vmihailenco/bufpool/buf_pool.go +++ /dev/null @@ -1,67 +0,0 @@ -package bufpool - -import ( - "log" - "sync" -) - -var thePool bufPool - -// Get retrieves a buffer of the appropriate length from the buffer pool or -// allocates a new one. Get may choose to ignore the pool and treat it as empty. -// Callers should not assume any relation between values passed to Put and the -// values returned by Get. -// -// If no suitable buffer exists in the pool, Get creates one. -func Get(length int) *Buffer { - return thePool.Get(length) -} - -// Put returns a buffer to the buffer pool. -func Put(buf *Buffer) { - thePool.Put(buf) -} - -type bufPool struct { - pools [steps]sync.Pool -} - -func (p *bufPool) Get(length int) *Buffer { - if length > maxPoolSize { - return NewBuffer(make([]byte, length)) - } - - idx := index(length) - if bufIface := p.pools[idx].Get(); bufIface != nil { - buf := bufIface.(*Buffer) - unlock(buf) - if length > buf.Cap() { - log.Println(idx, buf.Len(), buf.Cap(), buf.String()) - } - buf.buf = buf.buf[:length] - return buf - } - - b := make([]byte, length, indexSize(idx)) - return NewBuffer(b) -} - -func (p *bufPool) Put(buf *Buffer) { - length := buf.Cap() - if length > maxPoolSize || length < minSize { - return // drop it - } - - idx := prevIndex(length) - lock(buf) - p.pools[idx].Put(buf) -} - -func lock(buf *Buffer) { - buf.buf = buf.buf[:cap(buf.buf)] - buf.off = cap(buf.buf) + 1 -} - -func unlock(buf *Buffer) { - buf.off = 0 -} diff --git a/vendor/github.com/vmihailenco/bufpool/buffer.go b/vendor/github.com/vmihailenco/bufpool/buffer.go deleted file mode 100644 index a061a0b70..000000000 --- a/vendor/github.com/vmihailenco/bufpool/buffer.go +++ /dev/null @@ -1,397 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package bufpool - -// Simple byte buffer for marshaling data. - -import ( - "bytes" - "errors" - "io" - "unicode/utf8" -) - -// smallBufferSize is an initial allocation minimal capacity. -const smallBufferSize = 64 - -// A Buffer is a variable-sized buffer of bytes with Read and Write methods. -// The zero value for Buffer is an empty buffer ready to use. -type Buffer struct { - buf []byte // contents are the bytes buf[off : len(buf)] - off int // read at &buf[off], write at &buf[len(buf)] - lastRead readOp // last read operation, so that Unread* can work correctly. -} - -// The readOp constants describe the last action performed on -// the buffer, so that UnreadRune and UnreadByte can check for -// invalid usage. opReadRuneX constants are chosen such that -// converted to int they correspond to the rune size that was read. -type readOp int8 - -// Don't use iota for these, as the values need to correspond with the -// names and comments, which is easier to see when being explicit. -const ( - opRead readOp = -1 // Any other read operation. - opInvalid readOp = 0 // Non-read operation. - opReadRune1 readOp = 1 // Read rune of size 1. -) - -var errNegativeRead = errors.New("bytes.Buffer: reader returned negative count from Read") - -const maxInt = int(^uint(0) >> 1) - -// Bytes returns a slice of length b.Len() holding the unread portion of the buffer. -// The slice is valid for use only until the next buffer modification (that is, -// only until the next call to a method like Read, Write, Reset, or Truncate). -// The slice aliases the buffer content at least until the next buffer modification, -// so immediate changes to the slice will affect the result of future reads. -func (b *Buffer) Bytes() []byte { return b.buf[b.off:] } - -// String returns the contents of the unread portion of the buffer -// as a string. If the Buffer is a nil pointer, it returns "<nil>". -// -// To build strings more efficiently, see the strings.Builder type. -func (b *Buffer) String() string { - if b == nil { - // Special case, useful in debugging. - return "<nil>" - } - return string(b.buf[b.off:]) -} - -// empty reports whether the unread portion of the buffer is empty. -func (b *Buffer) empty() bool { return len(b.buf) <= b.off } - -// Len returns the number of bytes of the unread portion of the buffer; -// b.Len() == len(b.Bytes()). -func (b *Buffer) Len() int { return len(b.buf) - b.off } - -// Cap returns the capacity of the buffer's underlying byte slice, that is, the -// total space allocated for the buffer's data. -func (b *Buffer) Cap() int { return cap(b.buf) } - -// Truncate discards all but the first n unread bytes from the buffer -// but continues to use the same allocated storage. -// It panics if n is negative or greater than the length of the buffer. -func (b *Buffer) Truncate(n int) { - if n == 0 { - b.Reset() - return - } - b.lastRead = opInvalid - if n < 0 || n > b.Len() { - panic("bytes.Buffer: truncation out of range") - } - b.buf = b.buf[:b.off+n] -} - -// tryGrowByReslice is a inlineable version of grow for the fast-case where the -// internal buffer only needs to be resliced. -// It returns the index where bytes should be written and whether it succeeded. -func (b *Buffer) tryGrowByReslice(n int) (int, bool) { - if l := len(b.buf); n <= cap(b.buf)-l { - b.buf = b.buf[:l+n] - return l, true - } - return 0, false -} - -// Grow grows the buffer's capacity, if necessary, to guarantee space for -// another n bytes. After Grow(n), at least n bytes can be written to the -// buffer without another allocation. -// If n is negative, Grow will panic. -// If the buffer can't grow it will panic with ErrTooLarge. -func (b *Buffer) Grow(n int) { - if n < 0 { - panic("bytes.Buffer.Grow: negative count") - } - m := b.grow(n) - b.buf = b.buf[:m] -} - -// Write appends the contents of p to the buffer, growing the buffer as -// needed. The return value n is the length of p; err is always nil. If the -// buffer becomes too large, Write will panic with ErrTooLarge. -func (b *Buffer) Write(p []byte) (n int, err error) { - b.lastRead = opInvalid - m, ok := b.tryGrowByReslice(len(p)) - if !ok { - m = b.grow(len(p)) - } - return copy(b.buf[m:], p), nil -} - -// WriteString appends the contents of s to the buffer, growing the buffer as -// needed. The return value n is the length of s; err is always nil. If the -// buffer becomes too large, WriteString will panic with ErrTooLarge. -func (b *Buffer) WriteString(s string) (n int, err error) { - b.lastRead = opInvalid - m, ok := b.tryGrowByReslice(len(s)) - if !ok { - m = b.grow(len(s)) - } - return copy(b.buf[m:], s), nil -} - -// MinRead is the minimum slice size passed to a Read call by -// Buffer.ReadFrom. As long as the Buffer has at least MinRead bytes beyond -// what is required to hold the contents of r, ReadFrom will not grow the -// underlying buffer. -const minRead = 512 - -// ReadFrom reads data from r until EOF and appends it to the buffer, growing -// the buffer as needed. The return value n is the number of bytes read. Any -// error except io.EOF encountered during the read is also returned. If the -// buffer becomes too large, ReadFrom will panic with ErrTooLarge. -func (b *Buffer) ReadFrom(r io.Reader) (n int64, err error) { - b.lastRead = opInvalid - for { - i := b.grow(minRead) - b.buf = b.buf[:i] - m, e := r.Read(b.buf[i:cap(b.buf)]) - if m < 0 { - panic(errNegativeRead) - } - - b.buf = b.buf[:i+m] - n += int64(m) - if e == io.EOF { - return n, nil // e is EOF, so return nil explicitly - } - if e != nil { - return n, e - } - } -} - -// WriteTo writes data to w until the buffer is drained or an error occurs. -// The return value n is the number of bytes written; it always fits into an -// int, but it is int64 to match the io.WriterTo interface. Any error -// encountered during the write is also returned. -func (b *Buffer) WriteTo(w io.Writer) (n int64, err error) { - b.lastRead = opInvalid - if nBytes := b.Len(); nBytes > 0 { - m, e := w.Write(b.buf[b.off:]) - if m > nBytes { - panic("bytes.Buffer.WriteTo: invalid Write count") - } - b.off += m - n = int64(m) - if e != nil { - return n, e - } - // all bytes should have been written, by definition of - // Write method in io.Writer - if m != nBytes { - return n, io.ErrShortWrite - } - } - // Buffer is now empty; reset. - b.Reset() - return n, nil -} - -// WriteByte appends the byte c to the buffer, growing the buffer as needed. -// The returned error is always nil, but is included to match bufio.Writer's -// WriteByte. If the buffer becomes too large, WriteByte will panic with -// ErrTooLarge. -func (b *Buffer) WriteByte(c byte) error { - b.lastRead = opInvalid - m, ok := b.tryGrowByReslice(1) - if !ok { - m = b.grow(1) - } - b.buf[m] = c - return nil -} - -// WriteRune appends the UTF-8 encoding of Unicode code point r to the -// buffer, returning its length and an error, which is always nil but is -// included to match bufio.Writer's WriteRune. The buffer is grown as needed; -// if it becomes too large, WriteRune will panic with ErrTooLarge. -func (b *Buffer) WriteRune(r rune) (n int, err error) { - if r < utf8.RuneSelf { - _ = b.WriteByte(byte(r)) - return 1, nil - } - b.lastRead = opInvalid - m, ok := b.tryGrowByReslice(utf8.UTFMax) - if !ok { - m = b.grow(utf8.UTFMax) - } - n = utf8.EncodeRune(b.buf[m:m+utf8.UTFMax], r) - b.buf = b.buf[:m+n] - return n, nil -} - -// Read reads the next len(p) bytes from the buffer or until the buffer -// is drained. The return value n is the number of bytes read. If the -// buffer has no data to return, err is io.EOF (unless len(p) is zero); -// otherwise it is nil. -func (b *Buffer) Read(p []byte) (n int, err error) { - b.lastRead = opInvalid - if b.empty() { - // Buffer is empty, reset to recover space. - b.Reset() - if len(p) == 0 { - return 0, nil - } - return 0, io.EOF - } - n = copy(p, b.buf[b.off:]) - b.off += n - if n > 0 { - b.lastRead = opRead - } - return n, nil -} - -// Next returns a slice containing the next n bytes from the buffer, -// advancing the buffer as if the bytes had been returned by Read. -// If there are fewer than n bytes in the buffer, Next returns the entire buffer. -// The slice is only valid until the next call to a read or write method. -func (b *Buffer) Next(n int) []byte { - b.lastRead = opInvalid - m := b.Len() - if n > m { - n = m - } - data := b.buf[b.off : b.off+n] - b.off += n - if n > 0 { - b.lastRead = opRead - } - return data -} - -// ReadByte reads and returns the next byte from the buffer. -// If no byte is available, it returns error io.EOF. -func (b *Buffer) ReadByte() (byte, error) { - if b.empty() { - // Buffer is empty, reset to recover space. - b.Reset() - return 0, io.EOF - } - c := b.buf[b.off] - b.off++ - b.lastRead = opRead - return c, nil -} - -// ReadRune reads and returns the next UTF-8-encoded -// Unicode code point from the buffer. -// If no bytes are available, the error returned is io.EOF. -// If the bytes are an erroneous UTF-8 encoding, it -// consumes one byte and returns U+FFFD, 1. -func (b *Buffer) ReadRune() (r rune, size int, err error) { - if b.empty() { - // Buffer is empty, reset to recover space. - b.Reset() - return 0, 0, io.EOF - } - c := b.buf[b.off] - if c < utf8.RuneSelf { - b.off++ - b.lastRead = opReadRune1 - return rune(c), 1, nil - } - r, n := utf8.DecodeRune(b.buf[b.off:]) - b.off += n - b.lastRead = readOp(n) - return r, n, nil -} - -// UnreadRune unreads the last rune returned by ReadRune. -// If the most recent read or write operation on the buffer was -// not a successful ReadRune, UnreadRune returns an error. (In this regard -// it is stricter than UnreadByte, which will unread the last byte -// from any read operation.) -func (b *Buffer) UnreadRune() error { - if b.lastRead <= opInvalid { - return errors.New("bytes.Buffer: UnreadRune: previous operation was not a successful ReadRune") - } - if b.off >= int(b.lastRead) { - b.off -= int(b.lastRead) - } - b.lastRead = opInvalid - return nil -} - -var errUnreadByte = errors.New("bytes.Buffer: UnreadByte: previous operation was not a successful read") - -// UnreadByte unreads the last byte returned by the most recent successful -// read operation that read at least one byte. If a write has happened since -// the last read, if the last read returned an error, or if the read read zero -// bytes, UnreadByte returns an error. -func (b *Buffer) UnreadByte() error { - if b.lastRead == opInvalid { - return errUnreadByte - } - b.lastRead = opInvalid - if b.off > 0 { - b.off-- - } - return nil -} - -// ReadBytes reads until the first occurrence of delim in the input, -// returning a slice containing the data up to and including the delimiter. -// If ReadBytes encounters an error before finding a delimiter, -// it returns the data read before the error and the error itself (often io.EOF). -// ReadBytes returns err != nil if and only if the returned data does not end in -// delim. -func (b *Buffer) ReadBytes(delim byte) (line []byte, err error) { - slice, err := b.readSlice(delim) - // return a copy of slice. The buffer's backing array may - // be overwritten by later calls. - line = append(line, slice...) - return line, err -} - -// readSlice is like ReadBytes but returns a reference to internal buffer data. -func (b *Buffer) readSlice(delim byte) (line []byte, err error) { - i := bytes.IndexByte(b.buf[b.off:], delim) - end := b.off + i + 1 - if i < 0 { - end = len(b.buf) - err = io.EOF - } - line = b.buf[b.off:end] - b.off = end - b.lastRead = opRead - return line, err -} - -// ReadString reads until the first occurrence of delim in the input, -// returning a string containing the data up to and including the delimiter. -// If ReadString encounters an error before finding a delimiter, -// it returns the data read before the error and the error itself (often io.EOF). -// ReadString returns err != nil if and only if the returned data does not end -// in delim. -func (b *Buffer) ReadString(delim byte) (line string, err error) { - slice, err := b.readSlice(delim) - return string(slice), err -} - -// NewBuffer creates and initializes a new Buffer using buf as its -// initial contents. The new Buffer takes ownership of buf, and the -// caller should not use buf after this call. NewBuffer is intended to -// prepare a Buffer to read existing data. It can also be used to set -// the initial size of the internal buffer for writing. To do that, -// buf should have the desired capacity but a length of zero. -// -// In most cases, new(Buffer) (or just declaring a Buffer variable) is -// sufficient to initialize a Buffer. -func NewBuffer(buf []byte) *Buffer { return &Buffer{buf: buf} } - -// NewBufferString creates and initializes a new Buffer using string s as its -// initial contents. It is intended to prepare a buffer to read an existing -// string. -// -// In most cases, new(Buffer) (or just declaring a Buffer variable) is -// sufficient to initialize a Buffer. -func NewBufferString(s string) *Buffer { - return &Buffer{buf: []byte(s)} -} diff --git a/vendor/github.com/vmihailenco/bufpool/buffer_ext.go b/vendor/github.com/vmihailenco/bufpool/buffer_ext.go deleted file mode 100644 index 8a904bc5c..000000000 --- a/vendor/github.com/vmihailenco/bufpool/buffer_ext.go +++ /dev/null @@ -1,66 +0,0 @@ -package bufpool - -import "bytes" - -// Reset resets the buffer to be empty, -// but it retains the underlying storage for use by future writes. -// Reset is the same as Truncate(0). -func (b *Buffer) Reset() { - if b.off > cap(b.buf) { - panic("Buffer is used after Put") - } - b.buf = b.buf[:0] - b.off = 0 - b.lastRead = opInvalid -} - -func (b *Buffer) ResetBuf(buf []byte) { - if b.off > cap(b.buf) { - panic("Buffer is used after Put") - } - b.buf = buf[:0] - b.off = 0 - b.lastRead = opInvalid -} - -// grow grows the buffer to guarantee space for n more bytes. -// It returns the index where bytes should be written. -// If the buffer can't grow it will panic with ErrTooLarge. -func (b *Buffer) grow(n int) int { - if b.off > cap(b.buf) { - panic("Buffer is used after Put") - } - m := b.Len() - // If buffer is empty, reset to recover space. - if m == 0 && b.off != 0 { - b.Reset() - } - // Try to grow by means of a reslice. - if i, ok := b.tryGrowByReslice(n); ok { - return i - } - if b.buf == nil && n <= smallBufferSize { - b.buf = make([]byte, n, smallBufferSize) - return 0 - } - c := cap(b.buf) - if n <= c/2-m { - // We can slide things down instead of allocating a new - // slice. We only need m+n <= c to slide, but - // we instead let capacity get twice as large so we - // don't spend all our time copying. - copy(b.buf, b.buf[b.off:]) - } else if c > maxInt-c-n { - panic(bytes.ErrTooLarge) - } else { - // Not enough space anywhere, we need to allocate. - tmp := Get(2*c + n) - copy(tmp.buf, b.buf[b.off:]) - b.buf, tmp.buf = tmp.buf, b.buf - Put(tmp) - } - // Restore b.off and len(b.buf). - b.off = 0 - b.buf = b.buf[:m+n] - return m -} diff --git a/vendor/github.com/vmihailenco/bufpool/go.mod b/vendor/github.com/vmihailenco/bufpool/go.mod deleted file mode 100644 index 7f3096ae4..000000000 --- a/vendor/github.com/vmihailenco/bufpool/go.mod +++ /dev/null @@ -1,9 +0,0 @@ -module github.com/vmihailenco/bufpool - -go 1.13 - -require ( - github.com/kr/pretty v0.1.0 // indirect - github.com/stretchr/testify v1.5.1 - gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect -) diff --git a/vendor/github.com/vmihailenco/bufpool/pool.go b/vendor/github.com/vmihailenco/bufpool/pool.go deleted file mode 100644 index 3e1676b48..000000000 --- a/vendor/github.com/vmihailenco/bufpool/pool.go +++ /dev/null @@ -1,148 +0,0 @@ -package bufpool - -import ( - "math/bits" - "sync/atomic" -) - -const ( - minBitSize = 6 // 2**6=64 is a CPU cache line size - steps = 20 - - minSize = 1 << minBitSize // 64 bytes - maxSize = 1 << (minBitSize + steps - 1) // 32 mb - maxPoolSize = maxSize << 1 // 64 mb - - defaultServePctile = 0.95 - calibrateCallsThreshold = 42000 - defaultSize = 4096 -) - -// Pool represents byte buffer pool. -// -// Different pools should be used for different usage patterns to achieve better -// performance and lower memory usage. -type Pool struct { - calls [steps]uint32 - calibrating uint32 - - ServePctile float64 // default is 0.95 - serveSize uint32 -} - -func (p *Pool) getServeSize() int { - size := atomic.LoadUint32(&p.serveSize) - if size > 0 { - return int(size) - } - - for i := 0; i < len(p.calls); i++ { - calls := atomic.LoadUint32(&p.calls[i]) - if calls > 10 { - size := indexSize(i) - atomic.CompareAndSwapUint32(&p.serveSize, 0, uint32(size)) - return size - } - } - - return defaultSize -} - -// Get returns an empty buffer from the pool. Returned buffer capacity -// is determined by accumulated usage stats and changes over time. -// -// The buffer may be returned to the pool using Put or retained for further -// usage. In latter case buffer length must be updated using UpdateLen. -func (p *Pool) Get() *Buffer { - buf := Get(p.getServeSize()) - buf.Reset() - return buf -} - -// New returns an empty buffer bypassing the pool. Returned buffer capacity -// is determined by accumulated usage stats and changes over time. -func (p *Pool) New() *Buffer { - return NewBuffer(make([]byte, 0, p.getServeSize())) -} - -// Put returns buffer to the pool. -func (p *Pool) Put(buf *Buffer) { - length := buf.Len() - if length == 0 { - length = buf.Cap() - } - - p.UpdateLen(length) - - // Always put buf to the pool. - Put(buf) -} - -// UpdateLen updates stats about buffer length. -func (p *Pool) UpdateLen(bufLen int) { - idx := index(bufLen) - if atomic.AddUint32(&p.calls[idx], 1) > calibrateCallsThreshold { - p.calibrate() - } -} - -func (p *Pool) calibrate() { - if !atomic.CompareAndSwapUint32(&p.calibrating, 0, 1) { - return - } - - var callSum uint64 - var calls [steps]uint32 - - for i := 0; i < len(p.calls); i++ { - n := atomic.SwapUint32(&p.calls[i], 0) - calls[i] = n - callSum += uint64(n) - } - - serveSum := uint64(float64(callSum) * p.getServePctile()) - var serveSize int - - callSum = 0 - for i, numCall := range &calls { - callSum += uint64(numCall) - - if serveSize == 0 && callSum >= serveSum { - serveSize = indexSize(i) - break - } - } - - atomic.StoreUint32(&p.serveSize, uint32(serveSize)) - atomic.StoreUint32(&p.calibrating, 0) -} - -func (p *Pool) getServePctile() float64 { - if p.ServePctile > 0 { - return p.ServePctile - } - return defaultServePctile -} - -func index(n int) int { - if n == 0 { - return 0 - } - idx := bits.Len32(uint32((n - 1) >> minBitSize)) - if idx >= steps { - idx = steps - 1 - } - return idx -} - -func prevIndex(n int) int { - next := index(n) - if next == 0 || n == indexSize(next) { - return next - } - return next - 1 -} - -func indexSize(idx int) int { - return minSize << uint(idx) -} diff --git a/vendor/github.com/vmihailenco/tagparser/.travis.yml b/vendor/github.com/vmihailenco/tagparser/.travis.yml deleted file mode 100644 index ec5384523..000000000 --- a/vendor/github.com/vmihailenco/tagparser/.travis.yml +++ /dev/null @@ -1,24 +0,0 @@ -dist: xenial -sudo: false -language: go - -go: - - 1.11.x - - 1.12.x - - tip - -matrix: - allow_failures: - - go: tip - -env: - - GO111MODULE=on - -go_import_path: github.com/vmihailenco/tagparser - -before_install: - - curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b $(go env GOPATH)/bin v1.17.1 - -script: - - make - - golangci-lint run diff --git a/vendor/github.com/vmihailenco/tagparser/LICENSE b/vendor/github.com/vmihailenco/tagparser/LICENSE deleted file mode 100644 index 3fc93fdff..000000000 --- a/vendor/github.com/vmihailenco/tagparser/LICENSE +++ /dev/null @@ -1,25 +0,0 @@ -Copyright (c) 2019 The github.com/vmihailenco/tagparser Authors. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/vmihailenco/tagparser/Makefile b/vendor/github.com/vmihailenco/tagparser/Makefile deleted file mode 100644 index fe9dc5bdb..000000000 --- a/vendor/github.com/vmihailenco/tagparser/Makefile +++ /dev/null @@ -1,8 +0,0 @@ -all: - go test ./... - go test ./... -short -race - go test ./... -run=NONE -bench=. -benchmem - env GOOS=linux GOARCH=386 go test ./... - go vet ./... - go get github.com/gordonklaus/ineffassign - ineffassign . diff --git a/vendor/github.com/vmihailenco/tagparser/README.md b/vendor/github.com/vmihailenco/tagparser/README.md deleted file mode 100644 index 411aa5444..000000000 --- a/vendor/github.com/vmihailenco/tagparser/README.md +++ /dev/null @@ -1,24 +0,0 @@ -# Opinionated Golang tag parser - -[](https://travis-ci.org/vmihailenco/tagparser) -[](https://godoc.org/github.com/vmihailenco/tagparser) - -## Installation - -Install: - -```shell -go get -u github.com/vmihailenco/tagparser -``` - -## Quickstart - -```go -func ExampleParse() { - tag := tagparser.Parse("some_name,key:value,key2:'complex value'") - fmt.Println(tag.Name) - fmt.Println(tag.Options) - // Output: some_name - // map[key:value key2:'complex value'] -} -``` diff --git a/vendor/github.com/vmihailenco/tagparser/go.mod b/vendor/github.com/vmihailenco/tagparser/go.mod deleted file mode 100644 index 961a46ddb..000000000 --- a/vendor/github.com/vmihailenco/tagparser/go.mod +++ /dev/null @@ -1,3 +0,0 @@ -module github.com/vmihailenco/tagparser - -go 1.13 diff --git a/vendor/github.com/vmihailenco/tagparser/internal/parser/parser.go b/vendor/github.com/vmihailenco/tagparser/internal/parser/parser.go deleted file mode 100644 index 2de1c6f7b..000000000 --- a/vendor/github.com/vmihailenco/tagparser/internal/parser/parser.go +++ /dev/null @@ -1,82 +0,0 @@ -package parser - -import ( - "bytes" - - "github.com/vmihailenco/tagparser/internal" -) - -type Parser struct { - b []byte - i int -} - -func New(b []byte) *Parser { - return &Parser{ - b: b, - } -} - -func NewString(s string) *Parser { - return New(internal.StringToBytes(s)) -} - -func (p *Parser) Bytes() []byte { - return p.b[p.i:] -} - -func (p *Parser) Valid() bool { - return p.i < len(p.b) -} - -func (p *Parser) Read() byte { - if p.Valid() { - c := p.b[p.i] - p.Advance() - return c - } - return 0 -} - -func (p *Parser) Peek() byte { - if p.Valid() { - return p.b[p.i] - } - return 0 -} - -func (p *Parser) Advance() { - p.i++ -} - -func (p *Parser) Skip(skip byte) bool { - if p.Peek() == skip { - p.Advance() - return true - } - return false -} - -func (p *Parser) SkipBytes(skip []byte) bool { - if len(skip) > len(p.b[p.i:]) { - return false - } - if !bytes.Equal(p.b[p.i:p.i+len(skip)], skip) { - return false - } - p.i += len(skip) - return true -} - -func (p *Parser) ReadSep(sep byte) ([]byte, bool) { - ind := bytes.IndexByte(p.b[p.i:], sep) - if ind == -1 { - b := p.b[p.i:] - p.i = len(p.b) - return b, false - } - - b := p.b[p.i : p.i+ind] - p.i += ind + 1 - return b, true -} diff --git a/vendor/github.com/vmihailenco/tagparser/internal/safe.go b/vendor/github.com/vmihailenco/tagparser/internal/safe.go deleted file mode 100644 index 870fe541f..000000000 --- a/vendor/github.com/vmihailenco/tagparser/internal/safe.go +++ /dev/null @@ -1,11 +0,0 @@ -// +build appengine - -package internal - -func BytesToString(b []byte) string { - return string(b) -} - -func StringToBytes(s string) []byte { - return []byte(s) -} diff --git a/vendor/github.com/vmihailenco/tagparser/internal/unsafe.go b/vendor/github.com/vmihailenco/tagparser/internal/unsafe.go deleted file mode 100644 index f8bc18d91..000000000 --- a/vendor/github.com/vmihailenco/tagparser/internal/unsafe.go +++ /dev/null @@ -1,22 +0,0 @@ -// +build !appengine - -package internal - -import ( - "unsafe" -) - -// BytesToString converts byte slice to string. -func BytesToString(b []byte) string { - return *(*string)(unsafe.Pointer(&b)) -} - -// StringToBytes converts string to byte slice. -func StringToBytes(s string) []byte { - return *(*[]byte)(unsafe.Pointer( - &struct { - string - Cap int - }{s, len(s)}, - )) -} diff --git a/vendor/github.com/vmihailenco/tagparser/tagparser.go b/vendor/github.com/vmihailenco/tagparser/tagparser.go deleted file mode 100644 index 431002aef..000000000 --- a/vendor/github.com/vmihailenco/tagparser/tagparser.go +++ /dev/null @@ -1,181 +0,0 @@ -package tagparser - -import ( - "strings" - - "github.com/vmihailenco/tagparser/internal/parser" -) - -type Tag struct { - Name string - Options map[string]string -} - -func (t *Tag) HasOption(name string) bool { - _, ok := t.Options[name] - return ok -} - -func Parse(s string) *Tag { - p := &tagParser{ - Parser: parser.NewString(s), - } - p.parseKey() - return &p.Tag -} - -type tagParser struct { - *parser.Parser - - Tag Tag - hasName bool - key string -} - -func (p *tagParser) setTagOption(key, value string) { - key = strings.TrimSpace(key) - value = strings.TrimSpace(value) - - if !p.hasName { - p.hasName = true - if key == "" { - p.Tag.Name = value - return - } - } - if p.Tag.Options == nil { - p.Tag.Options = make(map[string]string) - } - if key == "" { - p.Tag.Options[value] = "" - } else { - p.Tag.Options[key] = value - } -} - -func (p *tagParser) parseKey() { - p.key = "" - - var b []byte - for p.Valid() { - c := p.Read() - switch c { - case ',': - p.Skip(' ') - p.setTagOption("", string(b)) - p.parseKey() - return - case ':': - p.key = string(b) - p.parseValue() - return - case '\'': - p.parseQuotedValue() - return - default: - b = append(b, c) - } - } - - if len(b) > 0 { - p.setTagOption("", string(b)) - } -} - -func (p *tagParser) parseValue() { - const quote = '\'' - - c := p.Peek() - if c == quote { - p.Skip(quote) - p.parseQuotedValue() - return - } - - var b []byte - for p.Valid() { - c = p.Read() - switch c { - case '\\': - b = append(b, p.Read()) - case '(': - b = append(b, c) - b = p.readBrackets(b) - case ',': - p.Skip(' ') - p.setTagOption(p.key, string(b)) - p.parseKey() - return - default: - b = append(b, c) - } - } - p.setTagOption(p.key, string(b)) -} - -func (p *tagParser) readBrackets(b []byte) []byte { - var lvl int -loop: - for p.Valid() { - c := p.Read() - switch c { - case '\\': - b = append(b, p.Read()) - case '(': - b = append(b, c) - lvl++ - case ')': - b = append(b, c) - lvl-- - if lvl < 0 { - break loop - } - default: - b = append(b, c) - } - } - return b -} - -func (p *tagParser) parseQuotedValue() { - const quote = '\'' - - var b []byte - b = append(b, quote) - - for p.Valid() { - bb, ok := p.ReadSep(quote) - if !ok { - b = append(b, bb...) - break - } - - if len(bb) > 0 && bb[len(bb)-1] == '\\' { - b = append(b, bb[:len(bb)-1]...) - b = append(b, quote) - continue - } - - b = append(b, bb...) - b = append(b, quote) - break - } - - p.setTagOption(p.key, string(b)) - if p.Skip(',') { - p.Skip(' ') - } - p.parseKey() -} - -func Unquote(s string) (string, bool) { - const quote = '\'' - - if len(s) < 2 { - return s, false - } - if s[0] == quote && s[len(s)-1] == quote { - return s[1 : len(s)-1], true - } - return s, false -} |