diff options
Diffstat (limited to 'vendor/google.golang.org/appengine/internal/api.go')
-rw-r--r-- | vendor/google.golang.org/appengine/internal/api.go | 347 |
1 files changed, 161 insertions, 186 deletions
diff --git a/vendor/google.golang.org/appengine/internal/api.go b/vendor/google.golang.org/appengine/internal/api.go index 721053c20..0569f5dd4 100644 --- a/vendor/google.golang.org/appengine/internal/api.go +++ b/vendor/google.golang.org/appengine/internal/api.go @@ -2,12 +2,14 @@ // Use of this source code is governed by the Apache 2.0 // license that can be found in the LICENSE file. +//go:build !appengine // +build !appengine package internal import ( "bytes" + "context" "errors" "fmt" "io/ioutil" @@ -24,7 +26,6 @@ import ( "time" "github.com/golang/protobuf/proto" - netcontext "golang.org/x/net/context" basepb "google.golang.org/appengine/internal/base" logpb "google.golang.org/appengine/internal/log" @@ -32,8 +33,7 @@ import ( ) const ( - apiPath = "/rpc_http" - defaultTicketSuffix = "/default.20150612t184001.0" + apiPath = "/rpc_http" ) var ( @@ -65,21 +65,22 @@ var ( IdleConnTimeout: 90 * time.Second, }, } - - defaultTicketOnce sync.Once - defaultTicket string - backgroundContextOnce sync.Once - backgroundContext netcontext.Context ) -func apiURL() *url.URL { +func apiURL(ctx context.Context) *url.URL { host, port := "appengine.googleapis.internal", "10001" if h := os.Getenv("API_HOST"); h != "" { host = h } + if hostOverride := ctx.Value(apiHostOverrideKey); hostOverride != nil { + host = hostOverride.(string) + } if p := os.Getenv("API_PORT"); p != "" { port = p } + if portOverride := ctx.Value(apiPortOverrideKey); portOverride != nil { + port = portOverride.(string) + } return &url.URL{ Scheme: "http", Host: host + ":" + port, @@ -87,82 +88,97 @@ func apiURL() *url.URL { } } -func handleHTTP(w http.ResponseWriter, r *http.Request) { - c := &context{ - req: r, - outHeader: w.Header(), - apiURL: apiURL(), - } - r = r.WithContext(withContext(r.Context(), c)) - c.req = r - - stopFlushing := make(chan int) +// Middleware wraps an http handler so that it can make GAE API calls +func Middleware(next http.Handler) http.Handler { + return handleHTTPMiddleware(executeRequestSafelyMiddleware(next)) +} - // Patch up RemoteAddr so it looks reasonable. - if addr := r.Header.Get(userIPHeader); addr != "" { - r.RemoteAddr = addr - } else if addr = r.Header.Get(remoteAddrHeader); addr != "" { - r.RemoteAddr = addr - } else { - // Should not normally reach here, but pick a sensible default anyway. - r.RemoteAddr = "127.0.0.1" - } - // The address in the headers will most likely be of these forms: - // 123.123.123.123 - // 2001:db8::1 - // net/http.Request.RemoteAddr is specified to be in "IP:port" form. - if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil { - // Assume the remote address is only a host; add a default port. - r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80") - } +func handleHTTPMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c := &aeContext{ + req: r, + outHeader: w.Header(), + } + r = r.WithContext(withContext(r.Context(), c)) + c.req = r + + stopFlushing := make(chan int) + + // Patch up RemoteAddr so it looks reasonable. + if addr := r.Header.Get(userIPHeader); addr != "" { + r.RemoteAddr = addr + } else if addr = r.Header.Get(remoteAddrHeader); addr != "" { + r.RemoteAddr = addr + } else { + // Should not normally reach here, but pick a sensible default anyway. + r.RemoteAddr = "127.0.0.1" + } + // The address in the headers will most likely be of these forms: + // 123.123.123.123 + // 2001:db8::1 + // net/http.Request.RemoteAddr is specified to be in "IP:port" form. + if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil { + // Assume the remote address is only a host; add a default port. + r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80") + } - // Start goroutine responsible for flushing app logs. - // This is done after adding c to ctx.m (and stopped before removing it) - // because flushing logs requires making an API call. - go c.logFlusher(stopFlushing) + if logToLogservice() { + // Start goroutine responsible for flushing app logs. + // This is done after adding c to ctx.m (and stopped before removing it) + // because flushing logs requires making an API call. + go c.logFlusher(stopFlushing) + } - executeRequestSafely(c, r) - c.outHeader = nil // make sure header changes aren't respected any more + next.ServeHTTP(c, r) + c.outHeader = nil // make sure header changes aren't respected any more - stopFlushing <- 1 // any logging beyond this point will be dropped + flushed := make(chan struct{}) + if logToLogservice() { + stopFlushing <- 1 // any logging beyond this point will be dropped - // Flush any pending logs asynchronously. - c.pendingLogs.Lock() - flushes := c.pendingLogs.flushes - if len(c.pendingLogs.lines) > 0 { - flushes++ - } - c.pendingLogs.Unlock() - flushed := make(chan struct{}) - go func() { - defer close(flushed) - // Force a log flush, because with very short requests we - // may not ever flush logs. - c.flushLog(true) - }() - w.Header().Set(logFlushHeader, strconv.Itoa(flushes)) + // Flush any pending logs asynchronously. + c.pendingLogs.Lock() + flushes := c.pendingLogs.flushes + if len(c.pendingLogs.lines) > 0 { + flushes++ + } + c.pendingLogs.Unlock() + go func() { + defer close(flushed) + // Force a log flush, because with very short requests we + // may not ever flush logs. + c.flushLog(true) + }() + w.Header().Set(logFlushHeader, strconv.Itoa(flushes)) + } - // Avoid nil Write call if c.Write is never called. - if c.outCode != 0 { - w.WriteHeader(c.outCode) - } - if c.outBody != nil { - w.Write(c.outBody) - } - // Wait for the last flush to complete before returning, - // otherwise the security ticket will not be valid. - <-flushed + // Avoid nil Write call if c.Write is never called. + if c.outCode != 0 { + w.WriteHeader(c.outCode) + } + if c.outBody != nil { + w.Write(c.outBody) + } + if logToLogservice() { + // Wait for the last flush to complete before returning, + // otherwise the security ticket will not be valid. + <-flushed + } + }) } -func executeRequestSafely(c *context, r *http.Request) { - defer func() { - if x := recover(); x != nil { - logf(c, 4, "%s", renderPanic(x)) // 4 == critical - c.outCode = 500 - } - }() +func executeRequestSafelyMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if x := recover(); x != nil { + c := w.(*aeContext) + logf(c, 4, "%s", renderPanic(x)) // 4 == critical + c.outCode = 500 + } + }() - http.DefaultServeMux.ServeHTTP(c, r) + next.ServeHTTP(w, r) + }) } func renderPanic(x interface{}) string { @@ -204,9 +220,9 @@ func renderPanic(x interface{}) string { return string(buf) } -// context represents the context of an in-flight HTTP request. +// aeContext represents the aeContext of an in-flight HTTP request. // It implements the appengine.Context and http.ResponseWriter interfaces. -type context struct { +type aeContext struct { req *http.Request outCode int @@ -218,8 +234,6 @@ type context struct { lines []*logpb.UserAppLogLine flushes int } - - apiURL *url.URL } var contextKey = "holds a *context" @@ -227,8 +241,8 @@ var contextKey = "holds a *context" // jointContext joins two contexts in a superficial way. // It takes values and timeouts from a base context, and only values from another context. type jointContext struct { - base netcontext.Context - valuesOnly netcontext.Context + base context.Context + valuesOnly context.Context } func (c jointContext) Deadline() (time.Time, bool) { @@ -252,94 +266,54 @@ func (c jointContext) Value(key interface{}) interface{} { // fromContext returns the App Engine context or nil if ctx is not // derived from an App Engine context. -func fromContext(ctx netcontext.Context) *context { - c, _ := ctx.Value(&contextKey).(*context) +func fromContext(ctx context.Context) *aeContext { + c, _ := ctx.Value(&contextKey).(*aeContext) return c } -func withContext(parent netcontext.Context, c *context) netcontext.Context { - ctx := netcontext.WithValue(parent, &contextKey, c) +func withContext(parent context.Context, c *aeContext) context.Context { + ctx := context.WithValue(parent, &contextKey, c) if ns := c.req.Header.Get(curNamespaceHeader); ns != "" { ctx = withNamespace(ctx, ns) } return ctx } -func toContext(c *context) netcontext.Context { - return withContext(netcontext.Background(), c) +func toContext(c *aeContext) context.Context { + return withContext(context.Background(), c) } -func IncomingHeaders(ctx netcontext.Context) http.Header { +func IncomingHeaders(ctx context.Context) http.Header { if c := fromContext(ctx); c != nil { return c.req.Header } return nil } -func ReqContext(req *http.Request) netcontext.Context { +func ReqContext(req *http.Request) context.Context { return req.Context() } -func WithContext(parent netcontext.Context, req *http.Request) netcontext.Context { +func WithContext(parent context.Context, req *http.Request) context.Context { return jointContext{ base: parent, valuesOnly: req.Context(), } } -// DefaultTicket returns a ticket used for background context or dev_appserver. -func DefaultTicket() string { - defaultTicketOnce.Do(func() { - if IsDevAppServer() { - defaultTicket = "testapp" + defaultTicketSuffix - return - } - appID := partitionlessAppID() - escAppID := strings.Replace(strings.Replace(appID, ":", "_", -1), ".", "_", -1) - majVersion := VersionID(nil) - if i := strings.Index(majVersion, "."); i > 0 { - majVersion = majVersion[:i] - } - defaultTicket = fmt.Sprintf("%s/%s.%s.%s", escAppID, ModuleName(nil), majVersion, InstanceID()) - }) - return defaultTicket -} - -func BackgroundContext() netcontext.Context { - backgroundContextOnce.Do(func() { - // Compute background security ticket. - ticket := DefaultTicket() - - c := &context{ - req: &http.Request{ - Header: http.Header{ - ticketHeader: []string{ticket}, - }, - }, - apiURL: apiURL(), - } - backgroundContext = toContext(c) - - // TODO(dsymonds): Wire up the shutdown handler to do a final flush. - go c.logFlusher(make(chan int)) - }) - - return backgroundContext -} - // RegisterTestRequest registers the HTTP request req for testing, such that -// any API calls are sent to the provided URL. It returns a closure to delete -// the registration. +// any API calls are sent to the provided URL. // It should only be used by aetest package. -func RegisterTestRequest(req *http.Request, apiURL *url.URL, decorate func(netcontext.Context) netcontext.Context) (*http.Request, func()) { - c := &context{ - req: req, - apiURL: apiURL, - } - ctx := withContext(decorate(req.Context()), c) - req = req.WithContext(ctx) - c.req = req - return req, func() {} +func RegisterTestRequest(req *http.Request, apiURL *url.URL, appID string) *http.Request { + ctx := req.Context() + ctx = withAPIHostOverride(ctx, apiURL.Hostname()) + ctx = withAPIPortOverride(ctx, apiURL.Port()) + ctx = WithAppIDOverride(ctx, appID) + + // use the unregistered request as a placeholder so that withContext can read the headers + c := &aeContext{req: req} + c.req = req.WithContext(withContext(ctx, c)) + return c.req } var errTimeout = &CallError{ @@ -348,7 +322,7 @@ var errTimeout = &CallError{ Timeout: true, } -func (c *context) Header() http.Header { return c.outHeader } +func (c *aeContext) Header() http.Header { return c.outHeader } // Copied from $GOROOT/src/pkg/net/http/transfer.go. Some response status // codes do not permit a response body (nor response entity headers such as @@ -365,7 +339,7 @@ func bodyAllowedForStatus(status int) bool { return true } -func (c *context) Write(b []byte) (int, error) { +func (c *aeContext) Write(b []byte) (int, error) { if c.outCode == 0 { c.WriteHeader(http.StatusOK) } @@ -376,7 +350,7 @@ func (c *context) Write(b []byte) (int, error) { return len(b), nil } -func (c *context) WriteHeader(code int) { +func (c *aeContext) WriteHeader(code int) { if c.outCode != 0 { logf(c, 3, "WriteHeader called multiple times on request.") // error level return @@ -384,10 +358,11 @@ func (c *context) WriteHeader(code int) { c.outCode = code } -func (c *context) post(body []byte, timeout time.Duration) (b []byte, err error) { +func post(ctx context.Context, body []byte, timeout time.Duration) (b []byte, err error) { + apiURL := apiURL(ctx) hreq := &http.Request{ Method: "POST", - URL: c.apiURL, + URL: apiURL, Header: http.Header{ apiEndpointHeader: apiEndpointHeaderValue, apiMethodHeader: apiMethodHeaderValue, @@ -396,13 +371,16 @@ func (c *context) post(body []byte, timeout time.Duration) (b []byte, err error) }, Body: ioutil.NopCloser(bytes.NewReader(body)), ContentLength: int64(len(body)), - Host: c.apiURL.Host, - } - if info := c.req.Header.Get(dapperHeader); info != "" { - hreq.Header.Set(dapperHeader, info) + Host: apiURL.Host, } - if info := c.req.Header.Get(traceHeader); info != "" { - hreq.Header.Set(traceHeader, info) + c := fromContext(ctx) + if c != nil { + if info := c.req.Header.Get(dapperHeader); info != "" { + hreq.Header.Set(dapperHeader, info) + } + if info := c.req.Header.Get(traceHeader); info != "" { + hreq.Header.Set(traceHeader, info) + } } tr := apiHTTPClient.Transport.(*http.Transport) @@ -444,7 +422,7 @@ func (c *context) post(body []byte, timeout time.Duration) (b []byte, err error) return hrespBody, nil } -func Call(ctx netcontext.Context, service, method string, in, out proto.Message) error { +func Call(ctx context.Context, service, method string, in, out proto.Message) error { if ns := NamespaceFromContext(ctx); ns != "" { if fn, ok := NamespaceMods[service]; ok { fn(in, ns) @@ -463,15 +441,11 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message) } c := fromContext(ctx) - if c == nil { - // Give a good error message rather than a panic lower down. - return errNotAppEngineContext - } // Apply transaction modifications if we're in a transaction. if t := transactionFromContext(ctx); t != nil { if t.finished { - return errors.New("transaction context has expired") + return errors.New("transaction aeContext has expired") } applyTransaction(in, &t.transaction) } @@ -487,20 +461,13 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message) return err } - ticket := c.req.Header.Get(ticketHeader) - // Use a test ticket under test environment. - if ticket == "" { - if appid := ctx.Value(&appIDOverrideKey); appid != nil { - ticket = appid.(string) + defaultTicketSuffix + ticket := "" + if c != nil { + ticket = c.req.Header.Get(ticketHeader) + if dri := c.req.Header.Get(devRequestIdHeader); IsDevAppServer() && dri != "" { + ticket = dri } } - // Fall back to use background ticket when the request ticket is not available in Flex or dev_appserver. - if ticket == "" { - ticket = DefaultTicket() - } - if dri := c.req.Header.Get(devRequestIdHeader); IsDevAppServer() && dri != "" { - ticket = dri - } req := &remotepb.Request{ ServiceName: &service, Method: &method, @@ -512,7 +479,7 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message) return err } - hrespBody, err := c.post(hreqBody, timeout) + hrespBody, err := post(ctx, hreqBody, timeout) if err != nil { return err } @@ -549,11 +516,11 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message) return proto.Unmarshal(res.Response, out) } -func (c *context) Request() *http.Request { +func (c *aeContext) Request() *http.Request { return c.req } -func (c *context) addLogLine(ll *logpb.UserAppLogLine) { +func (c *aeContext) addLogLine(ll *logpb.UserAppLogLine) { // Truncate long log lines. // TODO(dsymonds): Check if this is still necessary. const lim = 8 << 10 @@ -575,18 +542,20 @@ var logLevelName = map[int64]string{ 4: "CRITICAL", } -func logf(c *context, level int64, format string, args ...interface{}) { +func logf(c *aeContext, level int64, format string, args ...interface{}) { if c == nil { - panic("not an App Engine context") + panic("not an App Engine aeContext") } s := fmt.Sprintf(format, args...) s = strings.TrimRight(s, "\n") // Remove any trailing newline characters. - c.addLogLine(&logpb.UserAppLogLine{ - TimestampUsec: proto.Int64(time.Now().UnixNano() / 1e3), - Level: &level, - Message: &s, - }) - // Only duplicate log to stderr if not running on App Engine second generation + if logToLogservice() { + c.addLogLine(&logpb.UserAppLogLine{ + TimestampUsec: proto.Int64(time.Now().UnixNano() / 1e3), + Level: &level, + Message: &s, + }) + } + // Log to stdout if not deployed if !IsSecondGen() { log.Print(logLevelName[level] + ": " + s) } @@ -594,7 +563,7 @@ func logf(c *context, level int64, format string, args ...interface{}) { // flushLog attempts to flush any pending logs to the appserver. // It should not be called concurrently. -func (c *context) flushLog(force bool) (flushed bool) { +func (c *aeContext) flushLog(force bool) (flushed bool) { c.pendingLogs.Lock() // Grab up to 30 MB. We can get away with up to 32 MB, but let's be cautious. n, rem := 0, 30<<20 @@ -655,7 +624,7 @@ const ( forceFlushInterval = 60 * time.Second ) -func (c *context) logFlusher(stop <-chan int) { +func (c *aeContext) logFlusher(stop <-chan int) { lastFlush := time.Now() tick := time.NewTicker(flushInterval) for { @@ -673,6 +642,12 @@ func (c *context) logFlusher(stop <-chan int) { } } -func ContextForTesting(req *http.Request) netcontext.Context { - return toContext(&context{req: req}) +func ContextForTesting(req *http.Request) context.Context { + return toContext(&aeContext{req: req}) +} + +func logToLogservice() bool { + // TODO: replace logservice with json structured logs to $LOG_DIR/app.log.json + // where $LOG_DIR is /var/log in prod and some tmpdir in dev + return os.Getenv("LOG_TO_LOGSERVICE") != "0" } |