diff options
Diffstat (limited to 'internal/processing')
-rw-r--r-- | internal/processing/account/import.go | 374 |
1 files changed, 374 insertions, 0 deletions
diff --git a/internal/processing/account/import.go b/internal/processing/account/import.go new file mode 100644 index 000000000..200d971b8 --- /dev/null +++ b/internal/processing/account/import.go @@ -0,0 +1,374 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +package account + +import ( + "context" + "encoding/csv" + "errors" + "fmt" + "mime/multipart" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" +) + +func (p *Processor) ImportData( + ctx context.Context, + requester *gtsmodel.Account, + data *multipart.FileHeader, + importType string, + overwrite bool, +) gtserror.WithCode { + switch importType { + + case "following": + return p.importFollowing( + ctx, + requester, + data, + overwrite, + ) + + case "blocks": + return p.importBlocks( + ctx, + requester, + data, + overwrite, + ) + + default: + const text = "import type not yet supported" + return gtserror.NewErrorUnprocessableEntity(errors.New(text), text) + } +} + +func (p *Processor) importFollowing( + ctx context.Context, + requester *gtsmodel.Account, + followingData *multipart.FileHeader, + overwrite bool, +) gtserror.WithCode { + file, err := followingData.Open() + if err != nil { + err := fmt.Errorf("error opening following data file: %w", err) + return gtserror.NewErrorBadRequest(err, err.Error()) + } + defer file.Close() + + // Parse records out of the file. + records, err := csv.NewReader(file).ReadAll() + if err != nil { + err := fmt.Errorf("error reading following data file: %w", err) + return gtserror.NewErrorBadRequest(err, err.Error()) + } + + // Convert the records into a slice of barebones follows. + // + // Only TargetAccount.Username, TargetAccount.Domain, + // and ShowReblogs will be set on each Follow. + follows, err := p.converter.CSVToFollowing(ctx, records) + if err != nil { + err := fmt.Errorf("error converting records to follows: %w", err) + return gtserror.NewErrorBadRequest(err, err.Error()) + } + + // Do remaining processing of this import asynchronously. + f := importFollowingAsyncF(p, requester, follows, overwrite) + p.state.Workers.Processing.Queue.Push(f) + + return nil +} + +func importFollowingAsyncF( + p *Processor, + requester *gtsmodel.Account, + follows []*gtsmodel.Follow, + overwrite bool, +) func(context.Context) { + return func(ctx context.Context) { + // Map used to store wanted + // follow targets (if overwriting). + var wantedFollows map[string]struct{} + + if overwrite { + // If we're overwriting, we need to get current + // follow(-req)s owned by requester *before* + // making any changes, so that we can remove + // unwanted follows after we've created new ones. + prevFollows, err := p.state.DB.GetAccountFollows(ctx, requester.ID, nil) + if err != nil { + log.Errorf(ctx, "db error getting following: %v", err) + return + } + + prevFollowReqs, err := p.state.DB.GetAccountFollowRequesting(ctx, requester.ID, nil) + if err != nil { + log.Errorf(ctx, "db error getting follow requesting: %v", err) + return + } + + // Initialize new follows map. + wantedFollows = make(map[string]struct{}, len(follows)) + + // Once we've created (or tried to create) + // the required follows, go through previous + // follow(-request)s and remove unwanted ones. + defer func() { + + // AccountIDs to unfollow. + toRemove := []string{} + + // Check previous follows. + for _, prev := range prevFollows { + username := prev.TargetAccount.Username + domain := prev.TargetAccount.Domain + + _, wanted := wantedFollows[username+"@"+domain] + if !wanted { + toRemove = append(toRemove, prev.TargetAccountID) + } + } + + // Now any pending follow requests. + for _, prev := range prevFollowReqs { + username := prev.TargetAccount.Username + domain := prev.TargetAccount.Domain + + _, wanted := wantedFollows[username+"@"+domain] + if !wanted { + toRemove = append(toRemove, prev.TargetAccountID) + } + } + + // Remove each discovered + // unwanted follow. + for _, accountID := range toRemove { + if _, errWithCode := p.FollowRemove( + ctx, + requester, + accountID, + ); errWithCode != nil { + log.Errorf(ctx, "could not unfollow account: %v", errWithCode.Unwrap()) + continue + } + } + }() + } + + // Go through the follows parsed from CSV + // file, and create / update each one. + for _, follow := range follows { + var ( + // Username of the target. + username = follow.TargetAccount.Username + + // Domain of the target. + // Empty for our domain. + domain = follow.TargetAccount.Domain + + // Show reblogs on + // the new follow. + showReblogs = follow.ShowReblogs + ) + + if overwrite { + // We'll be overwriting, so store + // this new follow in our handy map. + wantedFollows[username+"@"+domain] = struct{}{} + } + + // Get the target account, dereferencing it if necessary. + targetAcct, _, err := p.federator.Dereferencer.GetAccountByUsernameDomain( + ctx, + requester.Username, + username, + domain, + ) + if err != nil { + log.Errorf(ctx, "could not retrieve account: %v", err) + continue + } + + // Use the processor's FollowCreate function + // to create or update the follow. This takes + // account of existing follows, and also sends + // the follow to the FromClientAPI processor. + if _, errWithCode := p.FollowCreate( + ctx, + requester, + &apimodel.AccountFollowRequest{ + ID: targetAcct.ID, + Reblogs: showReblogs, + }, + ); errWithCode != nil { + log.Errorf(ctx, "could not follow account: %v", errWithCode.Unwrap()) + continue + } + } + } +} + +func (p *Processor) importBlocks( + ctx context.Context, + requester *gtsmodel.Account, + blocksData *multipart.FileHeader, + overwrite bool, +) gtserror.WithCode { + file, err := blocksData.Open() + if err != nil { + err := fmt.Errorf("error opening blocks data file: %w", err) + return gtserror.NewErrorBadRequest(err, err.Error()) + } + defer file.Close() + + // Parse records out of the file. + records, err := csv.NewReader(file).ReadAll() + if err != nil { + err := fmt.Errorf("error reading blocks data file: %w", err) + return gtserror.NewErrorBadRequest(err, err.Error()) + } + + // Convert the records into a slice of barebones blocks. + // + // Only TargetAccount.Username and TargetAccount.Domain, + // will be set on each Block. + blocks, err := p.converter.CSVToBlocks(ctx, records) + if err != nil { + err := fmt.Errorf("error converting records to blocks: %w", err) + return gtserror.NewErrorBadRequest(err, err.Error()) + } + + // Do remaining processing of this import asynchronously. + f := importBlocksAsyncF(p, requester, blocks, overwrite) + p.state.Workers.Processing.Queue.Push(f) + + return nil +} + +func importBlocksAsyncF( + p *Processor, + requester *gtsmodel.Account, + blocks []*gtsmodel.Block, + overwrite bool, +) func(context.Context) { + return func(ctx context.Context) { + // Map used to store wanted + // block targets (if overwriting). + var wantedBlocks map[string]struct{} + + if overwrite { + // If we're overwriting, we need to get current + // blocks owned by requester *before* making any + // changes, so that we can remove unwanted blocks + // after we've created new ones. + var ( + prevBlocks []*gtsmodel.Block + err error + ) + + prevBlocks, err = p.state.DB.GetAccountBlocks(ctx, requester.ID, nil) + if err != nil { + log.Errorf(ctx, "db error getting blocks: %v", err) + return + } + + // Initialize new blocks map. + wantedBlocks = make(map[string]struct{}, len(blocks)) + + // Once we've created (or tried to create) + // the required blocks, go through previous + // blocks and remove unwanted ones. + defer func() { + for _, prev := range prevBlocks { + username := prev.TargetAccount.Username + domain := prev.TargetAccount.Domain + + _, wanted := wantedBlocks[username+"@"+domain] + if wanted { + // Leave this + // one alone. + continue + } + + if _, errWithCode := p.BlockRemove( + ctx, + requester, + prev.TargetAccountID, + ); errWithCode != nil { + log.Errorf(ctx, "could not unblock account: %v", errWithCode.Unwrap()) + continue + } + } + }() + } + + // Go through the blocks parsed from CSV + // file, and create / update each one. + for _, block := range blocks { + var ( + // Username of the target. + username = block.TargetAccount.Username + + // Domain of the target. + // Empty for our domain. + domain = block.TargetAccount.Domain + ) + + if overwrite { + // We'll be overwriting, so store + // this new block in our handy map. + wantedBlocks[username+"@"+domain] = struct{}{} + } + + // Get the target account, dereferencing it if necessary. + targetAcct, _, err := p.federator.Dereferencer.GetAccountByUsernameDomain( + ctx, + // Provide empty request user to use the + // instance account to deref the account. + // + // It's pointless to make lots of calls + // to a remote from an account that's about + // to block that account. + "", + username, + domain, + ) + if err != nil { + log.Errorf(ctx, "could not retrieve account: %v", err) + continue + } + + // Use the processor's BlockCreate function + // to create or update the block. This takes + // account of existing blocks, and also sends + // the block to the FromClientAPI processor. + if _, errWithCode := p.BlockCreate( + ctx, + requester, + targetAcct.ID, + ); errWithCode != nil { + log.Errorf(ctx, "could not block account: %v", errWithCode.Unwrap()) + continue + } + } + } +} |