summaryrefslogtreecommitdiff
path: root/vendor/codeberg.org/gruf/go-fastcopy/copy.go
blob: a9c11592732da3ac7e20b282c5ed5c8e16a7be9a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
package fastcopy

import (
	"io"
	"sync"
	_ "unsafe" // link to io.errInvalidWrite.
)

// pool is the package-level CopyPool behind the Buffer, CopyN and Copy
// package functions. It allocates 4096-byte buffers unless Buffer is used
// to change that size.
var pool = CopyPool{size: 4096}

// errInvalidWrite is bound by go:linkname to io's unexported
// errInvalidWrite. An invalid Write count is then reported with the same
// error value the io package uses for that condition.
//
// NOTE(review): newer Go toolchains restrict pull-style go:linkname
// references into the standard library; confirm this still links with the
// toolchain in use.
//
//go:linkname errInvalidWrite io.errInvalidWrite
var errInvalidWrite error

// CopyPool provides a memory pool of byte buffers for io copies from
// readers to writers. The zero value is ready to use and allocates
// 4096-byte buffers. Call Buffer before any copies to choose another size.
//
// A CopyPool contains a sync.Pool and so must not be copied after first use.
type CopyPool struct {
	// size is the length of newly allocated buffers. Values < 1 mean
	// "use the 4096-byte default" (see Buffer).
	size int

	// pool holds *[]byte buffers that Copy releases for later reuse.
	pool sync.Pool
}

// Buffer sets the buffer size used by the package-level pool when sz > 0,
// and returns the size now in effect. It is shorthand for CopyPool.Buffer on
// the global pool and shares its caveat: it is not atomically safe, so call
// it before any copying takes place.
func Buffer(sz int) int {
	gp := &pool
	return gp.Buffer(sz)
}

// CopyN copies n bytes (or until an error) from src to dst using the
// package-level pool. See CopyPool.CopyN for the full contract.
func CopyN(dst io.Writer, src io.Reader, n int64) (int64, error) {
	written, err := pool.CopyN(dst, src, n)
	return written, err
}

// Copy copies from src to dst until EOF or an error, using the
// package-level pool. See CopyPool.Copy for the full contract.
func Copy(dst io.Writer, src io.Reader) (int64, error) {
	gp := &pool
	return gp.Copy(dst, src)
}

// Buffer sets the size of newly allocated pool buffers to sz when sz > 0,
// and returns the size now in effect. A non-positive sz leaves the setting
// unchanged. If no size has ever been set, 4096 is reported without being
// stored.
//
// This is NOT atomically safe: call it BEFORE any other use of the CopyPool.
// Buffers already sitting in the pool keep their previous size.
func (cp *CopyPool) Buffer(sz int) int {
	switch {
	case sz > 0:
		cp.size = sz
	case cp.size < 1:
		// Unconfigured pool: report the default without storing it.
		return 4096
	}
	return cp.size
}

// CopyN performs the same logic as io.CopyN(): it copies n bytes (or until
// an error) from src to dst. The copy goes through cp.Copy, so a pooled
// buffer is used when neither side provides a fast path.
//
// When exactly n bytes are copied it returns n and a nil error. If src runs
// out of data before n bytes were copied and no other error occurred, the
// error is io.EOF. Otherwise the count and error from the copy are returned.
func (cp *CopyPool) CopyN(dst io.Writer, src io.Reader, n int64) (int64, error) {
	limited := io.LimitReader(src, n)
	copied, err := cp.Copy(dst, limited)

	switch {
	case copied == n:
		return n, nil
	case copied < n && err == nil:
		// The limit was not reached yet the copy ended cleanly,
		// so src itself must have hit EOF early.
		return copied, io.EOF
	default:
		return copied, err
	}
}

// Copy performs the same logic as io.Copy(), with the difference
// being that the byte buffer is acquired from a memory pool.
//
// Delegation happens first. If src implements io.WriterTo, the copy is
// handed to src.WriteTo. Failing that, if dst implements io.ReaderFrom, it
// is handed to dst.ReadFrom. In both cases no pooled buffer is used.
//
// Otherwise data is read from src into a pooled buffer and written to dst
// until src returns io.EOF or an error occurs. A newly allocated buffer has
// the size reported by Buffer(0).
//
// It returns the number of bytes written to dst and the first error
// encountered; io.EOF from src is treated as success. A Write reporting
// fewer bytes than given (without an error) yields io.ErrShortWrite. A Write
// reporting a negative or too-large count yields errInvalidWrite.
func (cp *CopyPool) Copy(dst io.Writer, src io.Reader) (int64, error) {
	// Prefer using io.WriterTo to do the copy (avoids alloc + copy).
	if wt, ok := src.(io.WriterTo); ok {
		return wt.WriteTo(dst)
	}

	// Prefer using io.ReaderFrom to do the copy.
	if rf, ok := dst.(io.ReaderFrom); ok {
		return rf.ReadFrom(src)
	}

	// Acquire a buffer from the pool, or allocate one of the configured
	// size if the pool is empty. Keep hold of the pointer itself and hand
	// that same pointer back to the pool. Putting &buf of a local slice
	// variable instead would move buf to the heap and allocate a fresh
	// *[]byte on every call.
	bufp, ok := cp.pool.Get().(*[]byte)
	if !ok {
		b := make([]byte, cp.Buffer(0))
		bufp = &b
	}
	defer cp.pool.Put(bufp)
	buf := *bufp

	var n int64
	for {
		nr, er := src.Read(buf)
		if nr > 0 {
			// Write whatever was read BEFORE looking at the read error,
			// so a final partial read is still forwarded to dst.
			nw, ew := dst.Write(buf[:nr])

			// A well-behaved writer reports 0 <= nw <= nr.
			if nw < 0 || nr < nw {
				if ew == nil {
					ew = errInvalidWrite
				}
				return n, ew
			}

			n += int64(nw)
			if ew != nil {
				return n, ew
			}
			if nw != nr {
				return n, io.ErrShortWrite
			}
		}

		if er != nil {
			if er == io.EOF {
				// EOF marks the successful end of the copy.
				return n, nil
			}
			return n, er
		}
	}
}