summaryrefslogtreecommitdiff
path: root/testrig/transportcontroller.go
diff options
context:
space:
mode:
Diffstat (limited to 'testrig/transportcontroller.go')
-rw-r--r--testrig/transportcontroller.go66
1 files changed, 49 insertions, 17 deletions
diff --git a/testrig/transportcontroller.go b/testrig/transportcontroller.go
index 1c75e1974..46a9b0fb2 100644
--- a/testrig/transportcontroller.go
+++ b/testrig/transportcontroller.go
@@ -78,7 +78,7 @@ type MockHTTPClient struct {
// to customize how the client is mocked.
//
// Note that you should never ever make ACTUAL http calls with this thing.
-func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relativeMediaPath string) *MockHTTPClient {
+func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relativeMediaPath string, extraPeople ...vocab.ActivityStreamsPerson) *MockHTTPClient {
mockHTTPClient := &MockHTTPClient{}
if do != nil {
@@ -95,10 +95,13 @@ func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relat
mockHTTPClient.TestTombstones = NewTestTombstones()
mockHTTPClient.do = func(req *http.Request) (*http.Response, error) {
- responseCode := http.StatusNotFound
- responseBytes := []byte(`{"error":"404 not found"}`)
- responseContentType := applicationJSON
- responseContentLength := len(responseBytes)
+ var (
+ responseCode = http.StatusNotFound
+ responseBytes = []byte(`{"error":"404 not found"}`)
+ responseContentType = applicationJSON
+ responseContentLength = len(responseBytes)
+ reqURLString = req.URL.String()
+ )
if req.Method == http.MethodPost {
b, err := io.ReadAll(req.Body)
@@ -106,26 +109,26 @@ func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relat
panic(err)
}
- if sI, loaded := mockHTTPClient.SentMessages.LoadOrStore(req.URL.String(), [][]byte{b}); loaded {
+ if sI, loaded := mockHTTPClient.SentMessages.LoadOrStore(reqURLString, [][]byte{b}); loaded {
s, ok := sI.([][]byte)
if !ok {
panic("SentMessages entry wasn't [][]byte")
}
s = append(s, b)
- mockHTTPClient.SentMessages.Store(req.URL.String(), s)
+ mockHTTPClient.SentMessages.Store(reqURLString, s)
}
responseCode = http.StatusOK
responseBytes = []byte(`{"ok":"accepted"}`)
responseContentType = applicationJSON
responseContentLength = len(responseBytes)
- } else if strings.Contains(req.URL.String(), ".well-known/webfinger") {
+ } else if strings.Contains(reqURLString, ".well-known/webfinger") {
responseCode, responseBytes, responseContentType, responseContentLength = WebfingerResponse(req)
- } else if strings.Contains(req.URL.String(), ".weird-webfinger-location/webfinger") {
+ } else if strings.Contains(reqURLString, ".weird-webfinger-location/webfinger") {
responseCode, responseBytes, responseContentType, responseContentLength = WebfingerResponse(req)
- } else if strings.Contains(req.URL.String(), ".well-known/host-meta") {
+ } else if strings.Contains(reqURLString, ".well-known/host-meta") {
responseCode, responseBytes, responseContentType, responseContentLength = HostMetaResponse(req)
- } else if note, ok := mockHTTPClient.TestRemoteStatuses[req.URL.String()]; ok {
+ } else if note, ok := mockHTTPClient.TestRemoteStatuses[reqURLString]; ok {
// the request is for a note that we have stored
noteI, err := streams.Serialize(note)
if err != nil {
@@ -139,7 +142,7 @@ func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relat
responseBytes = noteJSON
responseContentType = applicationActivityJSON
responseContentLength = len(noteJSON)
- } else if person, ok := mockHTTPClient.TestRemotePeople[req.URL.String()]; ok {
+ } else if person, ok := mockHTTPClient.TestRemotePeople[reqURLString]; ok {
// the request is for a person that we have stored
personI, err := streams.Serialize(person)
if err != nil {
@@ -153,7 +156,7 @@ func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relat
responseBytes = personJSON
responseContentType = applicationActivityJSON
responseContentLength = len(personJSON)
- } else if group, ok := mockHTTPClient.TestRemoteGroups[req.URL.String()]; ok {
+ } else if group, ok := mockHTTPClient.TestRemoteGroups[reqURLString]; ok {
// the request is for a person that we have stored
groupI, err := streams.Serialize(group)
if err != nil {
@@ -167,7 +170,7 @@ func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relat
responseBytes = groupJSON
responseContentType = applicationActivityJSON
responseContentLength = len(groupJSON)
- } else if service, ok := mockHTTPClient.TestRemoteServices[req.URL.String()]; ok {
+ } else if service, ok := mockHTTPClient.TestRemoteServices[reqURLString]; ok {
serviceI, err := streams.Serialize(service)
if err != nil {
panic(err)
@@ -180,7 +183,7 @@ func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relat
responseBytes = serviceJSON
responseContentType = applicationActivityJSON
responseContentLength = len(serviceJSON)
- } else if emoji, ok := mockHTTPClient.TestRemoteEmojis[req.URL.String()]; ok {
+ } else if emoji, ok := mockHTTPClient.TestRemoteEmojis[reqURLString]; ok {
emojiI, err := streams.Serialize(emoji)
if err != nil {
panic(err)
@@ -193,16 +196,45 @@ func NewMockHTTPClient(do func(req *http.Request) (*http.Response, error), relat
responseBytes = emojiJSON
responseContentType = applicationActivityJSON
responseContentLength = len(emojiJSON)
- } else if attachment, ok := mockHTTPClient.TestRemoteAttachments[req.URL.String()]; ok {
+ } else if attachment, ok := mockHTTPClient.TestRemoteAttachments[reqURLString]; ok {
responseCode = http.StatusOK
responseBytes = attachment.Data
responseContentType = attachment.ContentType
responseContentLength = len(attachment.Data)
- } else if _, ok := mockHTTPClient.TestTombstones[req.URL.String()]; ok {
+ } else if _, ok := mockHTTPClient.TestTombstones[reqURLString]; ok {
responseCode = http.StatusGone
responseBytes = []byte{}
responseContentType = "text/html"
responseContentLength = 0
+ } else {
+ for _, person := range extraPeople {
+ // For any extra people, check if the
+ // request matches one of:
+ //
+ // - Public key URI
+ // - ActivityPub URI/id
+ // - Web URL.
+ //
+ // Since this is a test environment,
+ // just assume all these values have
+ // been properly set.
+ if reqURLString == person.GetW3IDSecurityV1PublicKey().At(0).Get().GetJSONLDId().GetIRI().String() ||
+ reqURLString == person.GetJSONLDId().GetIRI().String() ||
+ reqURLString == person.GetActivityStreamsUrl().At(0).GetIRI().String() {
+ personI, err := streams.Serialize(person)
+ if err != nil {
+ panic(err)
+ }
+ personJSON, err := json.Marshal(personI)
+ if err != nil {
+ panic(err)
+ }
+ responseCode = http.StatusOK
+ responseBytes = personJSON
+ responseContentType = applicationActivityJSON
+ responseContentLength = len(personJSON)
+ }
+ }
}
log.Debugf(nil, "returning response %s", string(responseBytes))