diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/api/client/instance/instance.go | 1 | ||||
| -rw-r--r-- | internal/api/client/instance/instancepatch.go | 50 | ||||
| -rw-r--r-- | internal/api/model/instance.go | 16 | ||||
| -rw-r--r-- | internal/db/db.go | 8 | ||||
| -rw-r--r-- | internal/db/pg/instancestats.go | 52 | ||||
| -rw-r--r-- | internal/gtsmodel/instance.go | 2 | ||||
| -rw-r--r-- | internal/processing/instance.go | 116 | ||||
| -rw-r--r-- | internal/processing/processor.go | 4 | ||||
| -rw-r--r-- | internal/typeutils/internaltofrontend.go | 33 | ||||
| -rw-r--r-- | internal/util/regexes.go | 5 | ||||
| -rw-r--r-- | internal/util/validation.go | 57 | 
11 files changed, 329 insertions, 15 deletions
diff --git a/internal/api/client/instance/instance.go b/internal/api/client/instance/instance.go index 7fb08f29c..a5becf97d 100644 --- a/internal/api/client/instance/instance.go +++ b/internal/api/client/instance/instance.go @@ -34,5 +34,6 @@ func New(config *config.Config, processor processing.Processor, log *logrus.Logg  // Route satisfies the ClientModule interface  func (m *Module) Route(s router.Router) error {  	s.AttachHandler(http.MethodGet, InstanceInformationPath, m.InstanceInformationGETHandler) +	s.AttachHandler(http.MethodPatch, InstanceInformationPath, m.InstanceUpdatePATCHHandler)  	return nil  } diff --git a/internal/api/client/instance/instancepatch.go b/internal/api/client/instance/instancepatch.go new file mode 100644 index 000000000..ace7674c0 --- /dev/null +++ b/internal/api/client/instance/instancepatch.go @@ -0,0 +1,50 @@ +package instance + +import ( +	"net/http" + +	"github.com/gin-gonic/gin" +	"github.com/superseriousbusiness/gotosocial/internal/api/model" +	"github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +func (m *Module) InstanceUpdatePATCHHandler(c *gin.Context) { +	l := m.log.WithField("func", "InstanceUpdatePATCHHandler") +	authed, err := oauth.Authed(c, true, true, true, true) +	if err != nil { +		l.Debugf("couldn't auth: %s", err) +		c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) +		return +	} + +	// only admins can update instance settings +	if !authed.User.Admin { +		l.Debug("user is not an admin so cannot update instance settings") +		c.JSON(http.StatusUnauthorized, gin.H{"error": "not an admin"}) +		return +	} + +	l.Debugf("parsing request form %s", c.Request.Form) +	form := &model.InstanceSettingsUpdateRequest{} +	if err := c.ShouldBind(&form); err != nil || form == nil { +		l.Debugf("could not parse form from request: %s", err) +		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) +		return +	} + +	// if everything on the form is nil, then nothing has been set and we shouldn't continue +	if form.SiteTitle == nil && form.SiteContactUsername == nil && form.SiteContactEmail == nil && form.SiteShortDescription == nil && form.SiteDescription == nil && form.SiteTerms == nil && form.Avatar == nil && form.Header == nil { +		l.Debugf("could not parse form from request") +		c.JSON(http.StatusBadRequest, gin.H{"error": "empty form submitted"}) +		return +	} + +	i, errWithCode := m.processor.InstancePatch(form) +	if errWithCode != nil { +		l.Debugf("error with instance patch request: %s", errWithCode.Error()) +		c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) +		return +	} + +	c.JSON(http.StatusOK, i) +} diff --git a/internal/api/model/instance.go b/internal/api/model/instance.go index e4dad3559..834c6fb55 100644 --- a/internal/api/model/instance.go +++ b/internal/api/model/instance.go @@ -18,6 +18,8 @@  package model +import "mime/multipart" +  // Instance represents the software instance of Mastodon running on this domain. See https://docs.joinmastodon.org/entities/instance/  type Instance struct {  	// REQUIRED @@ -45,7 +47,7 @@ type Instance struct {  	// URLs of interest for clients apps.  	URLS *InstanceURLs `json:"urls,omitempty"`  	// Statistics about how much information the instance contains. -	Stats *InstanceStats `json:"stats,omitempty"` +	Stats map[string]int `json:"stats,omitempty"`  	// Banner image for the website.  	Thumbnail string `json:"thumbnail"`  	// A user that can be contacted, as an alternative to email. @@ -70,3 +72,15 @@ type InstanceStats struct {  	// Domains federated with this instance.  	DomainCount int `json:"domain_count"`  } + +// InstanceSettingsUpdateRequest is the form to be parsed on a PATCH to /api/v1/instance +type InstanceSettingsUpdateRequest struct { +	SiteTitle            *string               `form:"site_title" json:"site_title" xml:"site_title"` +	SiteContactUsername  *string               `form:"site_contact_username" json:"site_contact_username" xml:"site_contact_username"` +	SiteContactEmail     *string               `form:"site_contact_email" json:"site_contact_email" xml:"site_contact_email"` +	SiteShortDescription *string               `form:"site_short_description" json:"site_short_description" xml:"site_short_description"` +	SiteDescription      *string               `form:"site_description" json:"site_description" xml:"site_description"` +	SiteTerms            *string               `form:"site_terms" json:"site_terms" xml:"site_terms"` +	Avatar               *multipart.FileHeader `form:"avatar" json:"avatar" xml:"avatar"` +	Header               *multipart.FileHeader `form:"header" json:"header" xml:"header"` +} diff --git a/internal/db/db.go b/internal/db/db.go index 4e21358c3..204f04c71 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -253,6 +253,14 @@ type DB interface {  	// GetNotificationsForAccount returns a list of notifications that pertain to the given accountID.  	GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, error) +	// GetUserCountForInstance returns the number of known accounts registered with the given domain. +	GetUserCountForInstance(domain string) (int, error) + +	// GetStatusCountForInstance returns the number of known statuses posted from the given domain. +	GetStatusCountForInstance(domain string) (int, error) + +	// GetDomainCountForInstance returns the number of known instances known that the given domain federates with. +	GetDomainCountForInstance(domain string) (int, error)  	/*  		USEFUL CONVERSION FUNCTIONS  	*/ diff --git a/internal/db/pg/instancestats.go b/internal/db/pg/instancestats.go new file mode 100644 index 000000000..b57591d7b --- /dev/null +++ b/internal/db/pg/instancestats.go @@ -0,0 +1,52 @@ +package pg + +import ( +	"github.com/go-pg/pg/v10" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) { +	q := ps.conn.Model(&[]*gtsmodel.Account{}) + +	if domain == ps.config.Host { +		// if the domain is *this* domain, just count where the domain field is null +		q = q.Where("? IS NULL", pg.Ident("domain")) +	} else { +		q = q.Where("domain = ?", domain) +	} + +	// don't count the instance account or suspended users +	q = q.Where("username != ?", domain).Where("? IS NULL", pg.Ident("suspended_at")) + +	return q.Count() +} + +func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error) { +	q := ps.conn.Model(&[]*gtsmodel.Status{}) + +	if domain == ps.config.Host { +		// if the domain is *this* domain, just count where local is true +		q = q.Where("local = ?", true) +	} else { +		// join on the domain of the account +		q = q.Join("JOIN accounts AS account ON account.id = status.account_id"). +			Where("account.domain = ?", domain) +	} + +	return q.Count() +} + +func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error) { +	q := ps.conn.Model(&[]*gtsmodel.Instance{}) + +	if domain == ps.config.Host { +		// if the domain is *this* domain, just count other instances it knows about +		// TODO: exclude domains that are blocked or silenced +		q = q.Where("domain != ?", domain) +	} else { +		// TODO: implement federated domain counting properly for remote domains +		return 0, nil +	} + +	return q.Count() +} diff --git a/internal/gtsmodel/instance.go b/internal/gtsmodel/instance.go index 8b97ea2ae..f39231319 100644 --- a/internal/gtsmodel/instance.go +++ b/internal/gtsmodel/instance.go @@ -24,6 +24,8 @@ type Instance struct {  	ShortDescription string  	// Longer description of this instance  	Description string +	// Terms and conditions of this instance +	Terms string  	// Contact email address for this instance  	ContactEmail string  	// Contact account ID in the database for this instance diff --git a/internal/processing/instance.go b/internal/processing/instance.go index 9381a7315..a9b2fbd96 100644 --- a/internal/processing/instance.go +++ b/internal/processing/instance.go @@ -25,6 +25,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/db"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/util"  )  func (p *processor) InstanceGet(domain string) (*apimodel.Instance, gtserror.WithCode) { @@ -40,3 +41,118 @@ func (p *processor) InstanceGet(domain string) (*apimodel.Instance, gtserror.Wit  	return ai, nil  } + +func (p *processor) InstancePatch(form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.Instance, gtserror.WithCode) { +	// fetch the instance entry from the db for processing +	i := >smodel.Instance{} +	if err := p.db.GetWhere([]db.Where{{Key: "domain", Value: p.config.Host}}, i); err != nil { +		return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance %s: %s", p.config.Host, err)) +	} + +	// fetch the instance account from the db for processing +	ia := >smodel.Account{} +	if err := p.db.GetLocalAccountByUsername(p.config.Host, ia); err != nil { +		return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance account %s: %s", p.config.Host, err)) +	} + +	// validate & update site title if it's set on the form +	if form.SiteTitle != nil { +		if err := util.ValidateSiteTitle(*form.SiteTitle); err != nil { +			return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("site title invalid: %s", err)) +		} +		i.Title = *form.SiteTitle +	} + +	// validate & update site contact account if it's set on the form +	if form.SiteContactUsername != nil { +		// make sure the account with the given username exists in the db +		contactAccount := >smodel.Account{} +		if err := p.db.GetLocalAccountByUsername(*form.SiteContactUsername, contactAccount); err != nil { +			return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("account with username %s not retrievable", *form.SiteContactUsername)) +		} +		// make sure it has a user associated with it +		contactUser := >smodel.User{} +		if err := p.db.GetWhere([]db.Where{{Key: "account_id", Value: contactAccount.ID}}, contactUser); err != nil { +			return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("user for account with username %s not retrievable", *form.SiteContactUsername)) +		} +		// suspended accounts cannot be contact accounts +		if !contactAccount.SuspendedAt.IsZero() { +			err := fmt.Errorf("selected contact account %s is suspended", contactAccount.Username) +			return nil, gtserror.NewErrorBadRequest(err, err.Error()) +		} +		// unconfirmed or unapproved users cannot be contacts +		if contactUser.ConfirmedAt.IsZero() { +			err := fmt.Errorf("user of selected contact account %s is not confirmed", contactAccount.Username) +			return nil, gtserror.NewErrorBadRequest(err, err.Error()) +		} +		if !contactUser.Approved { +			err := fmt.Errorf("user of selected contact account %s is not approved", contactAccount.Username) +			return nil, gtserror.NewErrorBadRequest(err, err.Error()) +		} +		// contact account user must be admin or moderator otherwise what's the point of contacting them +		if !contactUser.Admin && !contactUser.Moderator { +			err := fmt.Errorf("user of selected contact account %s is neither admin nor moderator", contactAccount.Username) +			return nil, gtserror.NewErrorBadRequest(err, err.Error()) +		} +		i.ContactAccountID = contactAccount.ID +	} + +	// validate & update site contact email if it's set on the form +	if form.SiteContactEmail != nil { +		if err := util.ValidateEmail(*form.SiteContactEmail); err != nil { +			return nil, gtserror.NewErrorBadRequest(err, err.Error()) +		} +		i.ContactEmail = *form.SiteContactEmail +	} + +	// validate & update site short description if it's set on the form +	if form.SiteShortDescription != nil { +		if err := util.ValidateSiteShortDescription(*form.SiteShortDescription); err != nil { +			return nil, gtserror.NewErrorBadRequest(err, err.Error()) +		} +		i.ShortDescription = *form.SiteShortDescription +	} + +	// validate & update site description if it's set on the form +	if form.SiteDescription != nil { +		if err := util.ValidateSiteDescription(*form.SiteDescription); err != nil { +			return nil, gtserror.NewErrorBadRequest(err, err.Error()) +		} +		i.Description = *form.SiteDescription +	} + +	// validate & update site terms if it's set on the form +	if form.SiteTerms != nil { +		if err := util.ValidateSiteTerms(*form.SiteTerms); err != nil { +			return nil, gtserror.NewErrorBadRequest(err, err.Error()) +		} +		i.Terms = *form.SiteTerms +	} + +	// process avatar if provided +	if form.Avatar != nil && form.Avatar.Size != 0 { +		_, err := p.updateAccountAvatar(form.Avatar, ia.ID) +		if err != nil { +			return nil, gtserror.NewErrorBadRequest(err, "error processing avatar") +		} +	} + +	// process header if provided +	if form.Header != nil && form.Header.Size != 0 { +		_, err := p.updateAccountHeader(form.Header, ia.ID) +		if err != nil { +			return nil, gtserror.NewErrorBadRequest(err, "error processing header") +		} +	} + +	if err := p.db.UpdateByID(i.ID, i); err != nil { +		return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance %s: %s", p.config.Host, err)) +	} + +	ai, err := p.tc.InstanceToMasto(i) +	if err != nil { +		return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting instance to api representation: %s", err)) +	} + +	return ai, nil +} diff --git a/internal/processing/processor.go b/internal/processing/processor.go index 618fd641b..2cfa6e4e3 100644 --- a/internal/processing/processor.go +++ b/internal/processing/processor.go @@ -95,6 +95,10 @@ type Processor interface {  	// InstanceGet retrieves instance information for serving at api/v1/instance  	InstanceGet(domain string) (*apimodel.Instance, gtserror.WithCode) +	// InstancePatch updates this instance according to the given form. +	// +	// It should already be ascertained that the requesting account is authenticated and an admin. +	InstancePatch(form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.Instance, gtserror.WithCode)  	// MediaCreate handles the creation of a media attachment, using the given form.  	MediaCreate(authed *oauth.Auth, form *apimodel.AttachmentRequest) (*apimodel.Attachment, error) diff --git a/internal/typeutils/internaltofrontend.go b/internal/typeutils/internaltofrontend.go index 90460ecdd..a5984e068 100644 --- a/internal/typeutils/internaltofrontend.go +++ b/internal/typeutils/internaltofrontend.go @@ -511,9 +511,31 @@ func (c *converter) InstanceToMasto(i *gtsmodel.Instance) (*model.Instance, erro  		ShortDescription: i.ShortDescription,  		Email:            i.ContactEmail,  		Version:          i.Version, +		Stats:            make(map[string]int), +		ContactAccount:   &model.Account{},  	} +	// if the requested instance is *this* instance, we can add some extra information  	if i.Domain == c.config.Host { +		userCountKey := "user_count" +		statusCountKey := "status_count" +		domainCountKey := "domain_count" + +		userCount, err := c.db.GetUserCountForInstance(c.config.Host) +		if err == nil { +			mi.Stats[userCountKey] = userCount +		} + +		statusCount, err := c.db.GetStatusCountForInstance(c.config.Host) +		if err == nil { +			mi.Stats[statusCountKey] = statusCount +		} + +		domainCount, err := c.db.GetDomainCountForInstance(c.config.Host) +		if err == nil { +			mi.Stats[domainCountKey] = domainCount +		} +  		mi.Registrations = c.config.AccountsConfig.OpenRegistration  		mi.ApprovalRequired = c.config.AccountsConfig.RequireApproval  		mi.InvitesEnabled = false // TODO @@ -523,6 +545,17 @@ func (c *converter) InstanceToMasto(i *gtsmodel.Instance) (*model.Instance, erro  		}  	} +	// get the instance account if it exists and just skip if it doesn't +	ia := >smodel.Account{} +	if err := c.db.GetWhere([]db.Where{{Key: "username", Value: i.Domain}}, ia); err == nil { +		// instance account exists, get the header for the account if it exists +		attachment := >smodel.MediaAttachment{} +		if err := c.db.GetHeaderForAccountID(attachment, ia.ID); err == nil { +			// header exists, set it on the api model +			mi.Thumbnail = attachment.URL +		} +	} +  	// contact account is optional but let's try to get it  	if i.ContactAccountID != "" {  		ia := >smodel.Account{} diff --git a/internal/util/regexes.go b/internal/util/regexes.go index 586eb30df..6ad7b7404 100644 --- a/internal/util/regexes.go +++ b/internal/util/regexes.go @@ -24,12 +24,7 @@ import (  )  const ( -	minimumPasswordEntropy      = 60 // dictates password strength. See https://github.com/wagslane/go-password-validator -	minimumReasonLength         = 40 -	maximumReasonLength         = 500 -	maximumEmailLength          = 256  	maximumUsernameLength       = 64 -	maximumPasswordLength       = 64  	maximumEmojiShortcodeLength = 30  	maximumHashtagLength        = 30  ) diff --git a/internal/util/validation.go b/internal/util/validation.go index d392231bb..446f7a70e 100644 --- a/internal/util/validation.go +++ b/internal/util/validation.go @@ -27,6 +27,17 @@ import (  	"golang.org/x/text/language"  ) +const ( +	maximumPasswordLength         = 64 +	minimumPasswordEntropy        = 60 // dictates password strength. See https://github.com/wagslane/go-password-validator +	minimumReasonLength           = 40 +	maximumReasonLength           = 500 +	maximumSiteTitleLength        = 40 +	maximumShortDescriptionLength = 500 +	maximumDescriptionLength      = 5000 +	maximumSiteTermsLength        = 5000 +) +  // ValidateNewPassword returns an error if the given password is not sufficiently strong, or nil if it's ok.  func ValidateNewPassword(password string) error {  	if password == "" { @@ -47,12 +58,8 @@ func ValidateUsername(username string) error {  		return errors.New("no username provided")  	} -	if len(username) > maximumUsernameLength { -		return fmt.Errorf("username should be no more than %d chars but '%s' was %d", maximumUsernameLength, username, len(username)) -	} -  	if !usernameValidationRegex.MatchString(username) { -		return fmt.Errorf("given username %s was invalid: must contain only lowercase letters, numbers, and underscores", username) +		return fmt.Errorf("given username %s was invalid: must contain only lowercase letters, numbers, and underscores, max %d characters", username, maximumUsernameLength)  	}  	return nil @@ -65,10 +72,6 @@ func ValidateEmail(email string) error {  		return errors.New("no email provided")  	} -	if len(email) > maximumEmailLength { -		return fmt.Errorf("email address should be no more than %d chars but '%s' was %d", maximumEmailLength, email, len(email)) -	} -  	_, err := mail.ParseAddress(email)  	return err  } @@ -132,3 +135,39 @@ func ValidateEmojiShortcode(shortcode string) error {  	}  	return nil  } + +// ValidateSiteTitle ensures that the given site title is within spec. +func ValidateSiteTitle(siteTitle string) error { +	if len(siteTitle) > maximumSiteTitleLength { +		return fmt.Errorf("site title should be no more than %d chars but given title was %d", maximumSiteTitleLength, len(siteTitle)) +	} + +	return nil +} + +// ValidateSiteShortDescription ensures that the given site short description is within spec. +func ValidateSiteShortDescription(d string) error { +	if len(d) > maximumShortDescriptionLength { +		return fmt.Errorf("short description should be no more than %d chars but given description was %d", maximumShortDescriptionLength, len(d)) +	} + +	return nil +} + +// ValidateSiteDescription ensures that the given site description is within spec. +func ValidateSiteDescription(d string) error { +	if len(d) > maximumDescriptionLength { +		return fmt.Errorf("description should be no more than %d chars but given description was %d", maximumDescriptionLength, len(d)) +	} + +	return nil +} + +// ValidateSiteTerms ensures that the given site terms string is within spec. +func ValidateSiteTerms(t string) error { +	if len(t) > maximumSiteTermsLength { +		return fmt.Errorf("terms should be no more than %d chars but given terms was %d", maximumSiteTermsLength, len(t)) +	} + +	return nil +}  | 
