diff options
author | 2024-08-02 13:41:46 +0200 | |
---|---|---|
committer | 2024-08-02 12:41:46 +0100 | |
commit | 7b5917d6ae48f83c92f92d7277960cfa6ae8ec56 (patch) | |
tree | 93ee6999195060714f41f9b9476d4d76ad50520c /internal | |
parent | [chore] Take account of rotation data when calculating full size image dimens... (diff) | |
download | gotosocial-7b5917d6ae48f83c92f92d7277960cfa6ae8ec56.tar.xz |
[feature] Allow import of following and blocks via CSV (#3150)
* [feature] Import follows + blocks via settings panel
* test import follows
Diffstat (limited to 'internal')
-rw-r--r-- | internal/api/auth/token_test.go | 8 | ||||
-rw-r--r-- | internal/api/client.go | 4 | ||||
-rw-r--r-- | internal/api/client/accounts/accountdelete_test.go | 6 | ||||
-rw-r--r-- | internal/api/client/accounts/accountupdate_test.go | 6 | ||||
-rw-r--r-- | internal/api/client/admin/emojicreate_test.go | 8 | ||||
-rw-r--r-- | internal/api/client/admin/emojiupdate_test.go | 22 | ||||
-rw-r--r-- | internal/api/client/import/import.go | 195 | ||||
-rw-r--r-- | internal/api/client/import/import_test.go | 210 | ||||
-rw-r--r-- | internal/api/client/instance/instancepatch_test.go | 9 | ||||
-rw-r--r-- | internal/api/client/lists/listaccountsadd_test.go | 2 | ||||
-rw-r--r-- | internal/api/client/media/mediacreate_test.go | 8 | ||||
-rw-r--r-- | internal/api/client/media/mediaupdate_test.go | 4 | ||||
-rw-r--r-- | internal/api/client/polls/polls_vote_test.go | 2 | ||||
-rw-r--r-- | internal/api/model/exportimport.go | 22 | ||||
-rw-r--r-- | internal/processing/account/import.go | 374 | ||||
-rw-r--r-- | internal/typeutils/csv.go | 135 | ||||
-rw-r--r-- | internal/workers/workers.go | 12 |
17 files changed, 992 insertions, 35 deletions
diff --git a/internal/api/auth/token_test.go b/internal/api/auth/token_test.go index c97fce3b9..1c53b5b2e 100644 --- a/internal/api/auth/token_test.go +++ b/internal/api/auth/token_test.go @@ -57,7 +57,7 @@ func (suite *TokenTestSuite) TestRetrieveClientCredentialsOK() { testClient := suite.testClients["local_account_1"] requestBody, w, err := testrig.CreateMultipartFormData( - "", "", + nil, map[string][]string{ "grant_type": {"client_credentials"}, "client_id": {testClient.ID}, @@ -103,7 +103,7 @@ func (suite *TokenTestSuite) TestRetrieveAuthorizationCodeOK() { testUserAuthorizationToken := suite.testTokens["local_account_1_user_authorization_token"] requestBody, w, err := testrig.CreateMultipartFormData( - "", "", + nil, map[string][]string{ "grant_type": {"authorization_code"}, "client_id": {testClient.ID}, @@ -148,7 +148,7 @@ func (suite *TokenTestSuite) TestRetrieveAuthorizationCodeNoCode() { testClient := suite.testClients["local_account_1"] requestBody, w, err := testrig.CreateMultipartFormData( - "", "", + nil, map[string][]string{ "grant_type": {"authorization_code"}, "client_id": {testClient.ID}, @@ -180,7 +180,7 @@ func (suite *TokenTestSuite) TestRetrieveAuthorizationCodeWrongGrantType() { testClient := suite.testClients["local_account_1"] requestBody, w, err := testrig.CreateMultipartFormData( - "", "", + nil, map[string][]string{ "grant_type": {"client_credentials"}, "client_id": {testClient.ID}, diff --git a/internal/api/client.go b/internal/api/client.go index 64f185430..65d4f29d5 100644 --- a/internal/api/client.go +++ b/internal/api/client.go @@ -35,6 +35,7 @@ import ( filtersV2 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v2" "github.com/superseriousbusiness/gotosocial/internal/api/client/followedtags" "github.com/superseriousbusiness/gotosocial/internal/api/client/followrequests" + importdata "github.com/superseriousbusiness/gotosocial/internal/api/client/import" "github.com/superseriousbusiness/gotosocial/internal/api/client/instance" "github.com/superseriousbusiness/gotosocial/internal/api/client/interactionpolicies" "github.com/superseriousbusiness/gotosocial/internal/api/client/lists" @@ -76,6 +77,7 @@ type Client struct { filtersV2 *filtersV2.Module // api/v2/filters followRequests *followrequests.Module // api/v1/follow_requests followedTags *followedtags.Module // api/v1/followed_tags + importData *importdata.Module // api/v1/import instance *instance.Module // api/v1/instance interactionPolicies *interactionpolicies.Module // api/v1/interaction_policies lists *lists.Module // api/v1/lists @@ -125,6 +127,7 @@ func (c *Client) Route(r *router.Router, m ...gin.HandlerFunc) { c.filtersV2.Route(h) c.followRequests.Route(h) c.followedTags.Route(h) + c.importData.Route(h) c.instance.Route(h) c.interactionPolicies.Route(h) c.lists.Route(h) @@ -162,6 +165,7 @@ func NewClient(state *state.State, p *processing.Processor) *Client { filtersV2: filtersV2.New(p), followRequests: followrequests.New(p), followedTags: followedtags.New(p), + importData: importdata.New(p), instance: instance.New(p), interactionPolicies: interactionpolicies.New(p), lists: lists.New(p), diff --git a/internal/api/client/accounts/accountdelete_test.go b/internal/api/client/accounts/accountdelete_test.go index 2f5a25b4b..66a5fa097 100644 --- a/internal/api/client/accounts/accountdelete_test.go +++ b/internal/api/client/accounts/accountdelete_test.go @@ -35,7 +35,7 @@ func (suite *AccountDeleteTestSuite) TestAccountDeletePOSTHandler() { // set up the request // we're deleting zork requestBody, w, err := testrig.CreateMultipartFormData( - "", "", + nil, map[string][]string{ "password": {"password"}, }) @@ -57,7 +57,7 @@ func (suite *AccountDeleteTestSuite) TestAccountDeletePOSTHandlerWrongPassword() // set up the request // we're deleting zork requestBody, w, err := testrig.CreateMultipartFormData( - "", "", + nil, map[string][]string{ "password": {"aaaaaaaaaaaaaaaaaaaaaaaaaaaa"}, }) @@ -79,7 +79,7 @@ func (suite *AccountDeleteTestSuite) TestAccountDeletePOSTHandlerNoPassword() { // set up the request // we're deleting zork requestBody, w, err := testrig.CreateMultipartFormData( - "", "", + nil, map[string][]string{}) if err != nil { panic(err) diff --git a/internal/api/client/accounts/accountupdate_test.go b/internal/api/client/accounts/accountupdate_test.go index 09996e998..d0def500c 100644 --- a/internal/api/client/accounts/accountupdate_test.go +++ b/internal/api/client/accounts/accountupdate_test.go @@ -51,7 +51,7 @@ func (suite *AccountUpdateTestSuite) updateAccountFromForm(data map[string][]str } func (suite *AccountUpdateTestSuite) updateAccountFromFormData(data map[string][]string, expectedHTTPStatus int, expectedBody string) (*apimodel.Account, error) { - requestBody, w, err := testrig.CreateMultipartFormData("", "", data) + requestBody, w, err := testrig.CreateMultipartFormData(nil, data) if err != nil { suite.FailNow(err.Error()) } @@ -59,8 +59,8 @@ func (suite *AccountUpdateTestSuite) updateAccountFromFormData(data map[string][ return suite.updateAccount(requestBody.Bytes(), w.FormDataContentType(), expectedHTTPStatus, expectedBody) } -func (suite *AccountUpdateTestSuite) updateAccountFromFormDataWithFile(fieldName string, fileName string, data map[string][]string, expectedHTTPStatus int, expectedBody string) (*apimodel.Account, error) { - requestBody, w, err := testrig.CreateMultipartFormData(fieldName, fileName, data) +func (suite *AccountUpdateTestSuite) updateAccountFromFormDataWithFile(fieldName string, filePath string, data map[string][]string, expectedHTTPStatus int, expectedBody string) (*apimodel.Account, error) { + requestBody, w, err := testrig.CreateMultipartFormData(testrig.FileToDataF(fieldName, filePath), data) if err != nil { suite.FailNow(err.Error()) } diff --git a/internal/api/client/admin/emojicreate_test.go b/internal/api/client/admin/emojicreate_test.go index a687fb0af..9e985459b 100644 --- a/internal/api/client/admin/emojicreate_test.go +++ b/internal/api/client/admin/emojicreate_test.go @@ -38,7 +38,7 @@ type EmojiCreateTestSuite struct { func (suite *EmojiCreateTestSuite) TestEmojiCreateNewCategory() { // set up the request requestBody, w, err := testrig.CreateMultipartFormData( - "image", "../../../../testrig/media/rainbow-original.png", + testrig.FileToDataF("image", "../../../../testrig/media/rainbow-original.png"), map[string][]string{ "shortcode": {"new_emoji"}, "category": {"Test Emojis"}, // this category doesn't exist yet @@ -111,7 +111,7 @@ func (suite *EmojiCreateTestSuite) TestEmojiCreateNewCategory() { func (suite *EmojiCreateTestSuite) TestEmojiCreateExistingCategory() { // set up the request requestBody, w, err := testrig.CreateMultipartFormData( - "image", "../../../../testrig/media/rainbow-original.png", + testrig.FileToDataF("image", "../../../../testrig/media/rainbow-original.png"), map[string][]string{ "shortcode": {"new_emoji"}, "category": {"cute stuff"}, // this category already exists @@ -184,7 +184,7 @@ func (suite *EmojiCreateTestSuite) TestEmojiCreateExistingCategory() { func (suite *EmojiCreateTestSuite) TestEmojiCreateNoCategory() { // set up the request requestBody, w, err := testrig.CreateMultipartFormData( - "image", "../../../../testrig/media/rainbow-original.png", + testrig.FileToDataF("image", "../../../../testrig/media/rainbow-original.png"), map[string][]string{ "shortcode": {"new_emoji"}, "category": {""}, @@ -257,7 +257,7 @@ func (suite *EmojiCreateTestSuite) TestEmojiCreateNoCategory() { func (suite *EmojiCreateTestSuite) TestEmojiCreateAlreadyExists() { // set up the request -- use a shortcode that already exists for an emoji in the database requestBody, w, err := testrig.CreateMultipartFormData( - "image", "../../../../testrig/media/rainbow-original.png", + testrig.FileToDataF("image", "../../../../testrig/media/rainbow-original.png"), map[string][]string{ "shortcode": {"rainbow"}, }) diff --git a/internal/api/client/admin/emojiupdate_test.go b/internal/api/client/admin/emojiupdate_test.go index 073e3cec0..5df43d7ae 100644 --- a/internal/api/client/admin/emojiupdate_test.go +++ b/internal/api/client/admin/emojiupdate_test.go @@ -44,7 +44,7 @@ func (suite *EmojiUpdateTestSuite) TestEmojiUpdateNewCategory() { // set up the request requestBody, w, err := testrig.CreateMultipartFormData( - "", "", + nil, map[string][]string{ "category": {"New Category"}, // this category doesn't exist yet "type": {"modify"}, @@ -121,7 +121,7 @@ func (suite *EmojiUpdateTestSuite) TestEmojiUpdateSwitchCategory() { // set up the request requestBody, w, err := testrig.CreateMultipartFormData( - "", "", + nil, map[string][]string{ "type": {"modify"}, "category": {"cute stuff"}, @@ -198,7 +198,7 @@ func (suite *EmojiUpdateTestSuite) TestEmojiUpdateCopyRemoteToLocal() { // set up the request requestBody, w, err := testrig.CreateMultipartFormData( - "", "", + nil, map[string][]string{ "type": {"copy"}, "category": {"emojis i stole"}, @@ -276,7 +276,7 @@ func (suite *EmojiUpdateTestSuite) TestEmojiUpdateDisableEmoji() { // set up the request requestBody, w, err := testrig.CreateMultipartFormData( - "", "", + nil, map[string][]string{ "type": {"disable"}, }) @@ -317,7 +317,7 @@ func (suite *EmojiUpdateTestSuite) TestEmojiUpdateDisableLocalEmoji() { // set up the request requestBody, w, err := testrig.CreateMultipartFormData( - "", "", + nil, map[string][]string{ "type": {"disable"}, }) @@ -350,7 +350,7 @@ func (suite *EmojiUpdateTestSuite) TestEmojiUpdateModifyRemoteEmoji() { // set up the request requestBody, w, err := testrig.CreateMultipartFormData( - "image", "../../../../testrig/media/kip-original.gif", + testrig.FileToDataF("image", "../../../../testrig/media/kip-original.gif"), map[string][]string{ "type": {"modify"}, }) @@ -383,7 +383,7 @@ func (suite *EmojiUpdateTestSuite) TestEmojiUpdateModifyNoParams() { // set up the request requestBody, w, err := testrig.CreateMultipartFormData( - "", "", + nil, map[string][]string{ "type": {"modify"}, }) @@ -416,7 +416,7 @@ func (suite *EmojiUpdateTestSuite) TestEmojiUpdateCopyLocalToLocal() { // set up the request requestBody, w, err := testrig.CreateMultipartFormData( - "", "", + nil, map[string][]string{ "type": {"copy"}, "shortcode": {"bottoms"}, @@ -450,7 +450,7 @@ func (suite *EmojiUpdateTestSuite) TestEmojiUpdateCopyEmptyShortcode() { // set up the request requestBody, w, err := testrig.CreateMultipartFormData( - "", "", + nil, map[string][]string{ "type": {"copy"}, "shortcode": {""}, @@ -484,7 +484,7 @@ func (suite *EmojiUpdateTestSuite) TestEmojiUpdateCopyNoShortcode() { // set up the request requestBody, w, err := testrig.CreateMultipartFormData( - "", "", + nil, map[string][]string{ "type": {"copy"}, }) @@ -517,7 +517,7 @@ func (suite *EmojiUpdateTestSuite) TestEmojiUpdateCopyShortcodeAlreadyInUse() { // set up the request requestBody, w, err := testrig.CreateMultipartFormData( - "", "", + nil, map[string][]string{ "type": {"copy"}, "shortcode": {"rainbow"}, diff --git a/internal/api/client/import/import.go b/internal/api/client/import/import.go new file mode 100644 index 000000000..6d85a6b23 --- /dev/null +++ b/internal/api/client/import/import.go @@ -0,0 +1,195 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package importdata + +import ( + "errors" + "fmt" + "net/http" + "slices" + "strings" + + "github.com/gin-gonic/gin" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/processing" +) + +const ( + BasePath = "/v1/import" +) + +var types = []string{ + "following", + "blocks", +} + +var modes = []string{ + "merge", + "overwrite", +} + +type Module struct { + processor *processing.Processor +} + +func New(processor *processing.Processor) *Module { + return &Module{ + processor: processor, + } +} + +func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) { + attachHandler(http.MethodPost, BasePath, m.ImportPOSTHandler) +} + +// ImportPOSTHandler swagger:operation POST /api/v1/import importData +// +// Upload some CSV-formatted data to your account. +// +// This can be used to migrate data from a Mastodon-compatible CSV file to a GoToSocial account. +// +// Uploaded data will be processed asynchronously, and not all entries may be processed depending +// on domain blocks, user-level blocks, network availability of referenced accounts and statuses, etc. +// +// --- +// tags: +// - import-export +// +// consumes: +// - multipart/form-data +// +// produces: +// - application/json +// +// parameters: +// - +// name: data +// in: formData +// description: The CSV data file to upload. +// type: file +// required: true +// - +// name: type +// in: formData +// description: >- +// Type of entries contained in the data file: +// +// - `following` - accounts to follow. +// - `blocks` - accounts to block. +// type: string +// required: true +// - +// name: mode +// in: formData +// description: >- +// Mode to use when creating entries from the data file: +// +// - `merge` to merge entries in file with existing entries. +// - `overwrite` to replace existing entries with entries in file. +// type: string +// default: merge +// +// security: +// - OAuth2 Bearer: +// - write:accounts +// +// responses: +// '202': +// description: Upload accepted. +// '400': +// description: bad request +// '401': +// description: unauthorized +// '406': +// description: not acceptable +// '500': +// description: internal server error +func (m *Module) ImportPOSTHandler(c *gin.Context) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if authed.Account.IsMoving() { + apiutil.ForbiddenAfterMove(c) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) + return + } + + form := &apimodel.ImportRequest{} + if err := c.ShouldBind(form); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if form.Data == nil { + const text = "no data file provided" + err := errors.New(text) + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, text), m.processor.InstanceGetV1) + return + } + + if form.Type == "" { + const text = "no type provided" + err := errors.New(text) + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, text), m.processor.InstanceGetV1) + return + } + + form.Type = strings.ToLower(form.Type) + if !slices.Contains(types, form.Type) { + text := fmt.Sprintf("type %s not recognized, valid types are: %+v", form.Type, types) + err := errors.New(text) + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, text), m.processor.InstanceGetV1) + return + } + + if form.Mode != "" { + form.Mode = strings.ToLower(form.Mode) + if !slices.Contains(modes, form.Mode) { + text := fmt.Sprintf("mode %s not recognized, valid modes are: %+v", form.Mode, modes) + err := errors.New(text) + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, text), m.processor.InstanceGetV1) + return + } + } + overwrite := form.Mode == "overwrite" + + // Trigger the import. + errWithCode := m.processor.Account().ImportData( + c.Request.Context(), + authed.Account, + form.Data, + form.Type, + overwrite, + ) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + apiutil.JSON(c, http.StatusAccepted, gin.H{"status": "accepted"}) +} diff --git a/internal/api/client/import/import_test.go b/internal/api/client/import/import_test.go new file mode 100644 index 000000000..5129f862e --- /dev/null +++ b/internal/api/client/import/import_test.go @@ -0,0 +1,210 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package importdata_test + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/suite" + importdata "github.com/superseriousbusiness/gotosocial/internal/api/client/import" + "github.com/superseriousbusiness/gotosocial/internal/filter/visibility" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type ImportTestSuite struct { + // Suite interfaces + suite.Suite + state state.State + + // standard suite models + testTokens map[string]*gtsmodel.Token + testClients map[string]*gtsmodel.Client + testApplications map[string]*gtsmodel.Application + testUsers map[string]*gtsmodel.User + testAccounts map[string]*gtsmodel.Account + + // module being tested + importModule *importdata.Module +} + +func (suite *ImportTestSuite) SetupSuite() { + suite.testTokens = testrig.NewTestTokens() + suite.testClients = testrig.NewTestClients() + suite.testApplications = testrig.NewTestApplications() + suite.testUsers = testrig.NewTestUsers() + suite.testAccounts = testrig.NewTestAccounts() +} + +func (suite *ImportTestSuite) SetupTest() { + suite.state.Caches.Init() + + testrig.InitTestConfig() + testrig.InitTestLog() + + suite.state.DB = testrig.NewTestDB(&suite.state) + suite.state.Storage = testrig.NewInMemoryStorage() + + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + typeutils.NewConverter(&suite.state), + ) + + testrig.StandardDBSetup(suite.state.DB, nil) + testrig.StandardStorageSetup(suite.state.Storage, "../../../../testrig/media") + + mediaManager := testrig.NewTestMediaManager(&suite.state) + + federator := testrig.NewTestFederator( + &suite.state, + testrig.NewTestTransportController( + &suite.state, + testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), + ), + mediaManager, + ) + + processor := testrig.NewTestProcessor( + &suite.state, + federator, + testrig.NewEmailSender("../../../../web/template/", nil), + mediaManager, + ) + testrig.StartWorkers(&suite.state, processor.Workers()) + + suite.importModule = importdata.New(processor) +} + +func (suite *ImportTestSuite) TriggerHandler( + importData string, + importType string, + importMode string, +) { + // Set up request. + recorder := httptest.NewRecorder() + ctx, _ := testrig.CreateGinTestContext(recorder, nil) + + // Authorize the request ctx as though it + // had passed through API auth handlers. + ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) + ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"])) + ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) + ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) + + // Create test request. + b, w, err := testrig.CreateMultipartFormData( + testrig.StringToDataF("data", "data.csv", importData), + map[string][]string{ + "type": {importType}, + "mode": {importMode}, + }, + ) + if err != nil { + suite.FailNow(err.Error()) + } + + target := "http://localhost:8080/api/v1/import" + ctx.Request = httptest.NewRequest(http.MethodPost, target, bytes.NewReader(b.Bytes())) + ctx.Request.Header.Set("Accept", "application/json") + ctx.Request.Header.Set("Content-Type", w.FormDataContentType()) + + // Trigger handler. + suite.importModule.ImportPOSTHandler(ctx) + + if code := recorder.Code; code != http.StatusAccepted { + b, err := io.ReadAll(recorder.Body) + if err != nil { + panic(err) + } + suite.FailNow("", "expected 202, got %d: %s", code, string(b)) + } +} + +func (suite *ImportTestSuite) TearDownTest() { + testrig.StandardDBTeardown(suite.state.DB) + testrig.StandardStorageTeardown(suite.state.Storage) + testrig.StopWorkers(&suite.state) +} + +func (suite *ImportTestSuite) TestImportFollows() { + var ( + ctx = context.Background() + testAccount = suite.testAccounts["local_account_1"] + ) + + // Clear existing follows from Zork. + if err := suite.state.DB.DeleteAccountFollows(ctx, testAccount.ID); err != nil { + suite.FailNow(err.Error()) + } + + // Have zork refollow turtle and admin. + data := `Account address,Show boosts +admin@localhost:8080,true +1happyturtle@localhost:8080,true +` + + // Trigger the import handler. + suite.TriggerHandler(data, "following", "merge") + + // Wait for zork to be + // following admin. + if !testrig.WaitFor(func() bool { + f, err := suite.state.DB.IsFollowing( + ctx, + testAccount.ID, + suite.testAccounts["admin_account"].ID, + ) + if err != nil { + suite.FailNow(err.Error()) + } + + return f + }) { + suite.FailNow("timed out waiting for zork to follow admin") + } + + // Wait for zork to be + // follow req'ing turtle. + if !testrig.WaitFor(func() bool { + f, err := suite.state.DB.IsFollowRequested( + ctx, + testAccount.ID, + suite.testAccounts["local_account_2"].ID, + ) + if err != nil { + suite.FailNow(err.Error()) + } + + return f + }) { + suite.FailNow("timed out waiting for zork to follow req turtle") + } +} + +func TestImportTestSuite(t *testing.T) { + suite.Run(t, new(ImportTestSuite)) +} diff --git a/internal/api/client/instance/instancepatch_test.go b/internal/api/client/instance/instancepatch_test.go index f12638f82..bb391537e 100644 --- a/internal/api/client/instance/instancepatch_test.go +++ b/internal/api/client/instance/instancepatch_test.go @@ -37,7 +37,12 @@ type InstancePatchTestSuite struct { } func (suite *InstancePatchTestSuite) instancePatch(fieldName string, fileName string, extraFields map[string][]string) (code int, body []byte) { - requestBody, w, err := testrig.CreateMultipartFormData(fieldName, fileName, extraFields) + var dataF testrig.DataF + if fieldName != "" && fileName != "" { + dataF = testrig.FileToDataF(fieldName, fileName) + } + + requestBody, w, err := testrig.CreateMultipartFormData(dataF, extraFields) if err != nil { suite.FailNow(err.Error()) } @@ -499,7 +504,7 @@ func (suite *InstancePatchTestSuite) TestInstancePatch4() { func (suite *InstancePatchTestSuite) TestInstancePatch5() { requestBody, w, err := testrig.CreateMultipartFormData( - "", "", + nil, map[string][]string{ "short_description": {"<p>This is some html, which is <em>allowed</em> in short descriptions.</p>"}, }) diff --git a/internal/api/client/lists/listaccountsadd_test.go b/internal/api/client/lists/listaccountsadd_test.go index 492996882..7e44eeed3 100644 --- a/internal/api/client/lists/listaccountsadd_test.go +++ b/internal/api/client/lists/listaccountsadd_test.go @@ -60,7 +60,7 @@ func (suite *ListAccountsAddTestSuite) postListAccounts( requestPath := config.GetProtocol() + "://" + config.GetHost() + "/api/" + lists.BasePath + "/" + listID + "/accounts" // Prepare test body. - buf, w, err := testrig.CreateMultipartFormData("", "", map[string][]string{ + buf, w, err := testrig.CreateMultipartFormData(nil, map[string][]string{ "account_ids[]": accountIDs, }) diff --git a/internal/api/client/media/mediacreate_test.go b/internal/api/client/media/mediacreate_test.go index c256d18dc..4c2725681 100644 --- a/internal/api/client/media/mediacreate_test.go +++ b/internal/api/client/media/mediacreate_test.go @@ -149,7 +149,7 @@ func (suite *MediaCreateTestSuite) TestMediaCreateSuccessful() { } // create the request - buf, w, err := testrig.CreateMultipartFormData("file", "../../../../testrig/media/test-jpeg.jpg", map[string][]string{ + buf, w, err := testrig.CreateMultipartFormData(testrig.FileToDataF("file", "../../../../testrig/media/test-jpeg.jpg"), map[string][]string{ "description": {"this is a test image -- a cool background from somewhere"}, "focus": {"-0.5,0.5"}, }) @@ -234,7 +234,7 @@ func (suite *MediaCreateTestSuite) TestMediaCreateSuccessfulV2() { } // create the request - buf, w, err := testrig.CreateMultipartFormData("file", "../../../../testrig/media/test-jpeg.jpg", map[string][]string{ + buf, w, err := testrig.CreateMultipartFormData(testrig.FileToDataF("file", "../../../../testrig/media/test-jpeg.jpg"), map[string][]string{ "description": {"this is a test image -- a cool background from somewhere"}, "focus": {"-0.5,0.5"}, }) @@ -317,7 +317,7 @@ func (suite *MediaCreateTestSuite) TestMediaCreateLongDescription() { description := base64.RawStdEncoding.EncodeToString(descriptionBytes) // create the request - buf, w, err := testrig.CreateMultipartFormData("file", "../../../../testrig/media/test-jpeg.jpg", map[string][]string{ + buf, w, err := testrig.CreateMultipartFormData(testrig.FileToDataF("file", "../../../../testrig/media/test-jpeg.jpg"), map[string][]string{ "description": {description}, "focus": {"-0.5,0.5"}, }) @@ -358,7 +358,7 @@ func (suite *MediaCreateTestSuite) TestMediaCreateTooShortDescription() { ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) // create the request - buf, w, err := testrig.CreateMultipartFormData("file", "../../../../testrig/media/test-jpeg.jpg", map[string][]string{ + buf, w, err := testrig.CreateMultipartFormData(testrig.FileToDataF("file", "../../../../testrig/media/test-jpeg.jpg"), map[string][]string{ "description": {""}, // provide an empty description "focus": {"-0.5,0.5"}, }) diff --git a/internal/api/client/media/mediaupdate_test.go b/internal/api/client/media/mediaupdate_test.go index 43b2b6c51..c3a1fb340 100644 --- a/internal/api/client/media/mediaupdate_test.go +++ b/internal/api/client/media/mediaupdate_test.go @@ -140,7 +140,7 @@ func (suite *MediaUpdateTestSuite) TestUpdateImage() { ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) // create the request - buf, w, err := testrig.CreateMultipartFormData("", "", map[string][]string{ + buf, w, err := testrig.CreateMultipartFormData(nil, map[string][]string{ "id": {toUpdate.ID}, "description": {"new description!"}, "focus": {"-0.1,0.3"}, @@ -201,7 +201,7 @@ func (suite *MediaUpdateTestSuite) TestUpdateImageShortDescription() { ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) // create the request - buf, w, err := testrig.CreateMultipartFormData("", "", map[string][]string{ + buf, w, err := testrig.CreateMultipartFormData(nil, map[string][]string{ "id": {toUpdate.ID}, "description": {"new description!"}, "focus": {"-0.1,0.3"}, diff --git a/internal/api/client/polls/polls_vote_test.go b/internal/api/client/polls/polls_vote_test.go index 01bd941d3..54f98c192 100644 --- a/internal/api/client/polls/polls_vote_test.go +++ b/internal/api/client/polls/polls_vote_test.go @@ -107,7 +107,7 @@ func (suite *PollCreateTestSuite) formVoteInPoll( choicesStrs = append(choicesStrs, strconv.Itoa(choice)) } - body, w, err := testrig.CreateMultipartFormData("", "", map[string][]string{ + body, w, err := testrig.CreateMultipartFormData(nil, map[string][]string{ "choices[]": choicesStrs, }) diff --git a/internal/api/model/exportimport.go b/internal/api/model/exportimport.go index d87ed8cd3..88ea5489d 100644 --- a/internal/api/model/exportimport.go +++ b/internal/api/model/exportimport.go @@ -17,6 +17,8 @@ package model +import "mime/multipart" + // AccountExportStats models an account's stats // specifically for the purpose of informing about // export sizes at the /api/v1/exports/stats endpoint. @@ -58,3 +60,23 @@ type AccountExportStats struct { // example: 11 MutesCount int `json:"mutes_count"` } + +// AttachmentRequest models media attachment creation parameters. +// +// swagger: ignore +type ImportRequest struct { + // The CSV data to upload. + Data *multipart.FileHeader `form:"data" binding:"required"` + // Type of entries contained in the data file. + // + // - `following` - accounts to follow. + // - `lists` - lists of accounts. + // - `blocks` - accounts to block. + // - `mutes` - accounts to mute. + // - `bookmarks` - statuses to bookmark. + Type string `form:"type" binding:"required"` + // Mode to use when creating entries from the data file: + // - `merge` to merge entries in file with existing entries. + // - `overwrite` to replace existing entries with entries in file. + Mode string `form:"mode"` +} diff --git a/internal/processing/account/import.go b/internal/processing/account/import.go new file mode 100644 index 000000000..200d971b8 --- /dev/null +++ b/internal/processing/account/import.go @@ -0,0 +1,374 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package account + +import ( + "context" + "encoding/csv" + "errors" + "fmt" + "mime/multipart" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" +) + +func (p *Processor) ImportData( + ctx context.Context, + requester *gtsmodel.Account, + data *multipart.FileHeader, + importType string, + overwrite bool, +) gtserror.WithCode { + switch importType { + + case "following": + return p.importFollowing( + ctx, + requester, + data, + overwrite, + ) + + case "blocks": + return p.importBlocks( + ctx, + requester, + data, + overwrite, + ) + + default: + const text = "import type not yet supported" + return gtserror.NewErrorUnprocessableEntity(errors.New(text), text) + } +} + +func (p *Processor) importFollowing( + ctx context.Context, + requester *gtsmodel.Account, + followingData *multipart.FileHeader, + overwrite bool, +) gtserror.WithCode { + file, err := followingData.Open() + if err != nil { + err := fmt.Errorf("error opening following data file: %w", err) + return gtserror.NewErrorBadRequest(err, err.Error()) + } + defer file.Close() + + // Parse records out of the file. + records, err := csv.NewReader(file).ReadAll() + if err != nil { + err := fmt.Errorf("error reading following data file: %w", err) + return gtserror.NewErrorBadRequest(err, err.Error()) + } + + // Convert the records into a slice of barebones follows. + // + // Only TargetAccount.Username, TargetAccount.Domain, + // and ShowReblogs will be set on each Follow. + follows, err := p.converter.CSVToFollowing(ctx, records) + if err != nil { + err := fmt.Errorf("error converting records to follows: %w", err) + return gtserror.NewErrorBadRequest(err, err.Error()) + } + + // Do remaining processing of this import asynchronously. + f := importFollowingAsyncF(p, requester, follows, overwrite) + p.state.Workers.Processing.Queue.Push(f) + + return nil +} + +func importFollowingAsyncF( + p *Processor, + requester *gtsmodel.Account, + follows []*gtsmodel.Follow, + overwrite bool, +) func(context.Context) { + return func(ctx context.Context) { + // Map used to store wanted + // follow targets (if overwriting). + var wantedFollows map[string]struct{} + + if overwrite { + // If we're overwriting, we need to get current + // follow(-req)s owned by requester *before* + // making any changes, so that we can remove + // unwanted follows after we've created new ones. + prevFollows, err := p.state.DB.GetAccountFollows(ctx, requester.ID, nil) + if err != nil { + log.Errorf(ctx, "db error getting following: %v", err) + return + } + + prevFollowReqs, err := p.state.DB.GetAccountFollowRequesting(ctx, requester.ID, nil) + if err != nil { + log.Errorf(ctx, "db error getting follow requesting: %v", err) + return + } + + // Initialize new follows map. + wantedFollows = make(map[string]struct{}, len(follows)) + + // Once we've created (or tried to create) + // the required follows, go through previous + // follow(-request)s and remove unwanted ones. + defer func() { + + // AccountIDs to unfollow. + toRemove := []string{} + + // Check previous follows. + for _, prev := range prevFollows { + username := prev.TargetAccount.Username + domain := prev.TargetAccount.Domain + + _, wanted := wantedFollows[username+"@"+domain] + if !wanted { + toRemove = append(toRemove, prev.TargetAccountID) + } + } + + // Now any pending follow requests. + for _, prev := range prevFollowReqs { + username := prev.TargetAccount.Username + domain := prev.TargetAccount.Domain + + _, wanted := wantedFollows[username+"@"+domain] + if !wanted { + toRemove = append(toRemove, prev.TargetAccountID) + } + } + + // Remove each discovered + // unwanted follow. + for _, accountID := range toRemove { + if _, errWithCode := p.FollowRemove( + ctx, + requester, + accountID, + ); errWithCode != nil { + log.Errorf(ctx, "could not unfollow account: %v", errWithCode.Unwrap()) + continue + } + } + }() + } + + // Go through the follows parsed from CSV + // file, and create / update each one. + for _, follow := range follows { + var ( + // Username of the target. + username = follow.TargetAccount.Username + + // Domain of the target. + // Empty for our domain. + domain = follow.TargetAccount.Domain + + // Show reblogs on + // the new follow. + showReblogs = follow.ShowReblogs + ) + + if overwrite { + // We'll be overwriting, so store + // this new follow in our handy map. + wantedFollows[username+"@"+domain] = struct{}{} + } + + // Get the target account, dereferencing it if necessary. + targetAcct, _, err := p.federator.Dereferencer.GetAccountByUsernameDomain( + ctx, + requester.Username, + username, + domain, + ) + if err != nil { + log.Errorf(ctx, "could not retrieve account: %v", err) + continue + } + + // Use the processor's FollowCreate function + // to create or update the follow. This takes + // account of existing follows, and also sends + // the follow to the FromClientAPI processor. + if _, errWithCode := p.FollowCreate( + ctx, + requester, + &apimodel.AccountFollowRequest{ + ID: targetAcct.ID, + Reblogs: showReblogs, + }, + ); errWithCode != nil { + log.Errorf(ctx, "could not follow account: %v", errWithCode.Unwrap()) + continue + } + } + } +} + +func (p *Processor) importBlocks( + ctx context.Context, + requester *gtsmodel.Account, + blocksData *multipart.FileHeader, + overwrite bool, +) gtserror.WithCode { + file, err := blocksData.Open() + if err != nil { + err := fmt.Errorf("error opening blocks data file: %w", err) + return gtserror.NewErrorBadRequest(err, err.Error()) + } + defer file.Close() + + // Parse records out of the file. + records, err := csv.NewReader(file).ReadAll() + if err != nil { + err := fmt.Errorf("error reading blocks data file: %w", err) + return gtserror.NewErrorBadRequest(err, err.Error()) + } + + // Convert the records into a slice of barebones blocks. + // + // Only TargetAccount.Username and TargetAccount.Domain, + // will be set on each Block. + blocks, err := p.converter.CSVToBlocks(ctx, records) + if err != nil { + err := fmt.Errorf("error converting records to blocks: %w", err) + return gtserror.NewErrorBadRequest(err, err.Error()) + } + + // Do remaining processing of this import asynchronously. + f := importBlocksAsyncF(p, requester, blocks, overwrite) + p.state.Workers.Processing.Queue.Push(f) + + return nil +} + +func importBlocksAsyncF( + p *Processor, + requester *gtsmodel.Account, + blocks []*gtsmodel.Block, + overwrite bool, +) func(context.Context) { + return func(ctx context.Context) { + // Map used to store wanted + // block targets (if overwriting). + var wantedBlocks map[string]struct{} + + if overwrite { + // If we're overwriting, we need to get current + // blocks owned by requester *before* making any + // changes, so that we can remove unwanted blocks + // after we've created new ones. + var ( + prevBlocks []*gtsmodel.Block + err error + ) + + prevBlocks, err = p.state.DB.GetAccountBlocks(ctx, requester.ID, nil) + if err != nil { + log.Errorf(ctx, "db error getting blocks: %v", err) + return + } + + // Initialize new blocks map. + wantedBlocks = make(map[string]struct{}, len(blocks)) + + // Once we've created (or tried to create) + // the required blocks, go through previous + // blocks and remove unwanted ones. + defer func() { + for _, prev := range prevBlocks { + username := prev.TargetAccount.Username + domain := prev.TargetAccount.Domain + + _, wanted := wantedBlocks[username+"@"+domain] + if wanted { + // Leave this + // one alone. + continue + } + + if _, errWithCode := p.BlockRemove( + ctx, + requester, + prev.TargetAccountID, + ); errWithCode != nil { + log.Errorf(ctx, "could not unblock account: %v", errWithCode.Unwrap()) + continue + } + } + }() + } + + // Go through the blocks parsed from CSV + // file, and create / update each one. + for _, block := range blocks { + var ( + // Username of the target. + username = block.TargetAccount.Username + + // Domain of the target. + // Empty for our domain. + domain = block.TargetAccount.Domain + ) + + if overwrite { + // We'll be overwriting, so store + // this new block in our handy map. + wantedBlocks[username+"@"+domain] = struct{}{} + } + + // Get the target account, dereferencing it if necessary. + targetAcct, _, err := p.federator.Dereferencer.GetAccountByUsernameDomain( + ctx, + // Provide empty request user to use the + // instance account to deref the account. + // + // It's pointless to make lots of calls + // to a remote from an account that's about + // to block that account. + "", + username, + domain, + ) + if err != nil { + log.Errorf(ctx, "could not retrieve account: %v", err) + continue + } + + // Use the processor's BlockCreate function + // to create or update the block. This takes + // account of existing blocks, and also sends + // the block to the FromClientAPI processor. + if _, errWithCode := p.BlockCreate( + ctx, + requester, + targetAcct.ID, + ); errWithCode != nil { + log.Errorf(ctx, "could not block account: %v", errWithCode.Unwrap()) + continue + } + } + } +} diff --git a/internal/typeutils/csv.go b/internal/typeutils/csv.go index 2ef56cb0c..063e31d54 100644 --- a/internal/typeutils/csv.go +++ b/internal/typeutils/csv.go @@ -26,6 +26,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" ) func (c *Converter) AccountToExportStats( @@ -383,3 +384,137 @@ func (c *Converter) MutesToCSV( return records, nil } + +// CSVToFollowing converts a slice of CSV records +// to a slice of barebones *gtsmodel.Follow's, +// ready for further processing. +// +// Only TargetAccount.Username, TargetAccount.Domain, +// and ShowReblogs will be set on each Follow. +func (c *Converter) CSVToFollowing( + ctx context.Context, + records [][]string, +) ([]*gtsmodel.Follow, error) { + // We need to know our own domain for this. + // Try account domain, fall back to host. + var ( + thisHost = config.GetHost() + thisAccountDomain = config.GetAccountDomain() + follows = make([]*gtsmodel.Follow, 0, len(records)) + ) + + for _, record := range records { + if len(record) != 2 { + // Badly formatted, + // skip this one. + continue + } + + namestring := record[0] + if namestring == "" { + // Badly formatted, + // skip this one. + continue + } + + // Prepend with "@" + // if not included. + if namestring[0] != '@' { + namestring = "@" + namestring + } + + username, domain, err := util.ExtractNamestringParts(namestring) + if err != nil { + // Badly formatted, + // skip this one. + continue + } + + if domain == thisHost || domain == thisAccountDomain { + // Clear the domain, + // since it's ours. + domain = "" + } + + showReblogs, err := strconv.ParseBool(record[1]) + if err != nil { + // Badly formatted, + // skip this one. + continue + } + + // Looks good, whack it in the slice. + follows = append(follows, >smodel.Follow{ + TargetAccount: >smodel.Account{ + Username: username, + Domain: domain, + }, + ShowReblogs: &showReblogs, + }) + } + + return follows, nil +} + +// CSVToBlocks converts a slice of CSV records +// to a slice of barebones *gtsmodel.Block's, +// ready for further processing. +// +// Only TargetAccount.Username and TargetAccount.Domain +// will be set on each Block. +func (c *Converter) CSVToBlocks( + ctx context.Context, + records [][]string, +) ([]*gtsmodel.Block, error) { + // We need to know our own domain for this. + // Try account domain, fall back to host. + var ( + thisHost = config.GetHost() + thisAccountDomain = config.GetAccountDomain() + blocks = make([]*gtsmodel.Block, 0, len(records)) + ) + + for _, record := range records { + if len(record) != 1 { + // Badly formatted, + // skip this one. + continue + } + + namestring := record[0] + if namestring == "" { + // Badly formatted, + // skip this one. + continue + } + + // Prepend with "@" + // if not included. + if namestring[0] != '@' { + namestring = "@" + namestring + } + + username, domain, err := util.ExtractNamestringParts(namestring) + if err != nil { + // Badly formatted, + // skip this one. + continue + } + + if domain == thisHost || domain == thisAccountDomain { + // Clear the domain, + // since it's ours. + domain = "" + } + + // Looks good, whack it in the slice. + blocks = append(blocks, >smodel.Block{ + TargetAccount: >smodel.Account{ + Username: username, + Domain: domain, + }, + }) + } + + return blocks, nil +} diff --git a/internal/workers/workers.go b/internal/workers/workers.go index 377a9d899..657522903 100644 --- a/internal/workers/workers.go +++ b/internal/workers/workers.go @@ -49,6 +49,11 @@ type Workers struct { // for asynchronous dereferencer jobs. Dereference FnWorkerPool + // Processing provides a worker pool + // for asynchronous processing jobs, + // eg., import tasks, admin tasks. + Processing FnWorkerPool + // prevent pass-by-value. _ nocopy } @@ -81,6 +86,10 @@ func (w *Workers) Start() { n = 4 * maxprocs w.Dereference.Start(n) log.Infof(nil, "started %d dereference workers", n) + + n = 4 * maxprocs + w.Processing.Start(n) + log.Infof(nil, "started %d processing workers", n) } // Stop will stop all of the contained @@ -101,6 +110,9 @@ func (w *Workers) Stop() { w.Dereference.Stop() log.Info(nil, "stopped dereference workers") + + w.Processing.Stop() + log.Info(nil, "stopped processing workers") } // nocopy when embedded will signal linter to |