summaryrefslogtreecommitdiff
path: root/vendor/codeberg.org/gruf/go-fastcopy/copy.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/codeberg.org/gruf/go-fastcopy/copy.go')
-rw-r--r--vendor/codeberg.org/gruf/go-fastcopy/copy.go134
1 files changed, 134 insertions, 0 deletions
diff --git a/vendor/codeberg.org/gruf/go-fastcopy/copy.go b/vendor/codeberg.org/gruf/go-fastcopy/copy.go
new file mode 100644
index 000000000..4716b140f
--- /dev/null
+++ b/vendor/codeberg.org/gruf/go-fastcopy/copy.go
@@ -0,0 +1,134 @@
+package fastcopy
+
+import (
+ "io"
+ "sync"
+ _ "unsafe" // link to io.errInvalidWrite.
+)
+
+var (
+ // global pool instance.
+ pool = CopyPool{size: 4096}
+
+ //go:linkname errInvalidWrite io.errInvalidWrite
+ errInvalidWrite error
+)
+
+// CopyPool provides a memory pool of byte
+// buffers for io copies from readers to writers.
+type CopyPool struct {
+ size int
+ pool sync.Pool
+}
+
+// See CopyPool.Buffer().
+func Buffer(sz int) int {
+ return pool.Buffer(sz)
+}
+
+// See CopyPool.CopyN().
+func CopyN(dst io.Writer, src io.Reader, n int64) (int64, error) {
+ return pool.CopyN(dst, src, n)
+}
+
+// See CopyPool.Copy().
+func Copy(dst io.Writer, src io.Reader) (int64, error) {
+ return pool.Copy(dst, src)
+}
+
+// Buffer sets the pool buffer size to allocate. Returns current size.
+// Note this is NOT atomically safe, please call BEFORE other calls to CopyPool.
+func (cp *CopyPool) Buffer(sz int) int {
+ if sz > 0 {
+ // update size
+ cp.size = sz
+ } else if cp.size < 1 {
+ // default size
+ return 4096
+ }
+ return cp.size
+}
+
+// CopyN performs the same logic as io.CopyN(), with the difference
+// being that the byte buffer is acquired from a memory pool.
+func (cp *CopyPool) CopyN(dst io.Writer, src io.Reader, n int64) (int64, error) {
+ written, err := cp.Copy(dst, io.LimitReader(src, n))
+ if written == n {
+ return n, nil
+ }
+ if written < n && err == nil {
+ // src stopped early; must have been EOF.
+ err = io.EOF
+ }
+ return written, err
+}
+
+// Copy performs the same logic as io.Copy(), with the difference
+// being that the byte buffer is acquired from a memory pool.
+func (cp *CopyPool) Copy(dst io.Writer, src io.Reader) (int64, error) {
+ // Prefer using io.WriterTo to do the copy (avoids alloc + copy)
+ if wt, ok := src.(io.WriterTo); ok {
+ return wt.WriteTo(dst)
+ }
+
+ // Prefer using io.ReaderFrom to do the copy.
+ if rt, ok := dst.(io.ReaderFrom); ok {
+ return rt.ReadFrom(src)
+ }
+
+ var buf []byte
+
+ if b, ok := cp.pool.Get().([]byte); ok {
+ // Acquired buf from pool
+ buf = b
+ } else {
+ // Allocate new buffer of size
+ buf = make([]byte, cp.Buffer(0))
+ }
+
+ // Defer release to pool
+ defer cp.pool.Put(buf)
+
+ var n int64
+ for {
+ // Perform next read into buf
+ nr, err := src.Read(buf)
+ if nr > 0 {
+ // We error check AFTER checking
+ // no. read bytes so incomplete
+ // read still gets written up to nr.
+
+ // Perform next write from buf
+ nw, ew := dst.Write(buf[0:nr])
+
+ // Check for valid write
+ if nw < 0 || nr < nw {
+ if ew == nil {
+ ew = errInvalidWrite
+ }
+ return n, ew
+ }
+
+ // Incr total count
+ n += int64(nw)
+
+ // Check write error
+ if ew != nil {
+ return n, ew
+ }
+
+ // Check unequal read/writes
+ if nr != nw {
+ return n, io.ErrShortWrite
+ }
+ }
+
+ // Return on err
+ if err != nil {
+ if err == io.EOF {
+ err = nil // expected
+ }
+ return n, err
+ }
+ }
+}