diff options
Diffstat (limited to 'testrig/transportcontroller.go')
-rw-r--r-- | testrig/transportcontroller.go | 66 |
1 files changed, 49 insertions, 17 deletions
diff --git a/testrig/transportcontroller.go b/testrig/transportcontroller.go index 1c75e1974..46a9b0fb2 100644 --- a/testrig/transportcontroller.go +++ b/testrig/transportcontroller.go @@ -78,7 +78,7 @@ type MockHTTPClient struct { // to customize how the client is mocked. // // Note that you should never ever make ACTUAL http calls with this thing. -func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relativeMediaPath string) *MockHTTPClient { +func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relativeMediaPath string, extraPeople ...vocab.ActivityStreamsPerson) *MockHTTPClient { mockHTTPClient := &MockHTTPClient{} if do != nil { @@ -95,10 +95,13 @@ func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relat mockHTTPClient.TestTombstones = NewTestTombstones() mockHTTPClient.do = func(req *http.Request) (*http.Response, error) { - responseCode := http.StatusNotFound - responseBytes := []byte(`{"error":"404 not found"}`) - responseContentType := applicationJSON - responseContentLength := len(responseBytes) + var ( + responseCode = http.StatusNotFound + responseBytes = []byte(`{"error":"404 not found"}`) + responseContentType = applicationJSON + responseContentLength = len(responseBytes) + reqURLString = req.URL.String() + ) if req.Method == http.MethodPost { b, err := io.ReadAll(req.Body) @@ -106,26 +109,26 @@ func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relat panic(err) } - if sI, loaded := mockHTTPClient.SentMessages.LoadOrStore(req.URL.String(), [][]byte{b}); loaded { + if sI, loaded := mockHTTPClient.SentMessages.LoadOrStore(reqURLString, [][]byte{b}); loaded { s, ok := sI.([][]byte) if !ok { panic("SentMessages entry wasn't [][]byte") } s = append(s, b) - mockHTTPClient.SentMessages.Store(req.URL.String(), s) + mockHTTPClient.SentMessages.Store(reqURLString, s) } responseCode = http.StatusOK responseBytes = []byte(`{"ok":"accepted"}`) responseContentType = applicationJSON responseContentLength = len(responseBytes) - } else if strings.Contains(req.URL.String(), ".well-known/webfinger") { + } else if strings.Contains(reqURLString, ".well-known/webfinger") { responseCode, responseBytes, responseContentType, responseContentLength = WebfingerResponse(req) - } else if strings.Contains(req.URL.String(), ".weird-webfinger-location/webfinger") { + } else if strings.Contains(reqURLString, ".weird-webfinger-location/webfinger") { responseCode, responseBytes, responseContentType, responseContentLength = WebfingerResponse(req) - } else if strings.Contains(req.URL.String(), ".well-known/host-meta") { + } else if strings.Contains(reqURLString, ".well-known/host-meta") { responseCode, responseBytes, responseContentType, responseContentLength = HostMetaResponse(req) - } else if note, ok := mockHTTPClient.TestRemoteStatuses[req.URL.String()]; ok { + } else if note, ok := mockHTTPClient.TestRemoteStatuses[reqURLString]; ok { // the request is for a note that we have stored noteI, err := streams.Serialize(note) if err != nil { @@ -139,7 +142,7 @@ func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relat responseBytes = noteJSON responseContentType = applicationActivityJSON responseContentLength = len(noteJSON) - } else if person, ok := mockHTTPClient.TestRemotePeople[req.URL.String()]; ok { + } else if person, ok := mockHTTPClient.TestRemotePeople[reqURLString]; ok { // the request is for a person that we have stored personI, err := streams.Serialize(person) if err != nil { @@ -153,7 +156,7 @@ func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relat responseBytes = personJSON responseContentType = applicationActivityJSON responseContentLength = len(personJSON) - } else if group, ok := mockHTTPClient.TestRemoteGroups[req.URL.String()]; ok { + } else if group, ok := mockHTTPClient.TestRemoteGroups[reqURLString]; ok { // the request is for a person that we have stored groupI, err := streams.Serialize(group) if err != nil { @@ -167,7 +170,7 @@ func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relat responseBytes = groupJSON responseContentType = applicationActivityJSON responseContentLength = len(groupJSON) - } else if service, ok := mockHTTPClient.TestRemoteServices[req.URL.String()]; ok { + } else if service, ok := mockHTTPClient.TestRemoteServices[reqURLString]; ok { serviceI, err := streams.Serialize(service) if err != nil { panic(err) @@ -180,7 +183,7 @@ func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relat responseBytes = serviceJSON responseContentType = applicationActivityJSON responseContentLength = len(serviceJSON) - } else if emoji, ok := mockHTTPClient.TestRemoteEmojis[req.URL.String()]; ok { + } else if emoji, ok := mockHTTPClient.TestRemoteEmojis[reqURLString]; ok { emojiI, err := streams.Serialize(emoji) if err != nil { panic(err) @@ -193,16 +196,45 @@ func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relat responseBytes = emojiJSON responseContentType = applicationActivityJSON responseContentLength = len(emojiJSON) - } else if attachment, ok := mockHTTPClient.TestRemoteAttachments[req.URL.String()]; ok { + } else if attachment, ok := mockHTTPClient.TestRemoteAttachments[reqURLString]; ok { responseCode = http.StatusOK responseBytes = attachment.Data responseContentType = attachment.ContentType responseContentLength = len(attachment.Data) - } else if _, ok := mockHTTPClient.TestTombstones[req.URL.String()]; ok { + } else if _, ok := mockHTTPClient.TestTombstones[reqURLString]; ok { responseCode = http.StatusGone responseBytes = []byte{} responseContentType = "text/html" responseContentLength = 0 + } else { + for _, person := range extraPeople { + // For any extra people, check if the + // request matches one of: + // + // - Public key URI + // - ActivityPub URI/id + // - Web URL. + // + // Since this is a test environment, + // just assume all these values have + // been properly set. + if reqURLString == person.GetW3IDSecurityV1PublicKey().At(0).Get().GetJSONLDId().GetIRI().String() || + reqURLString == person.GetJSONLDId().GetIRI().String() || + reqURLString == person.GetActivityStreamsUrl().At(0).GetIRI().String() { + personI, err := streams.Serialize(person) + if err != nil { + panic(err) + } + personJSON, err := json.Marshal(personI) + if err != nil { + panic(err) + } + responseCode = http.StatusOK + responseBytes = personJSON + responseContentType = applicationActivityJSON + responseContentLength = len(personJSON) + } + } } log.Debugf(nil, "returning response %s", string(responseBytes)) |