summaryrefslogtreecommitdiff
path: root/internal/subscriptions
diff options
context:
space:
mode:
Diffstat (limited to 'internal/subscriptions')
-rw-r--r--internal/subscriptions/domainperms.go39
-rw-r--r--internal/subscriptions/subscriptions_test.go4
2 files changed, 25 insertions, 18 deletions
diff --git a/internal/subscriptions/domainperms.go b/internal/subscriptions/domainperms.go
index c9f569f94..8da9064f6 100644
--- a/internal/subscriptions/domainperms.go
+++ b/internal/subscriptions/domainperms.go
@@ -438,7 +438,7 @@ func (s *Subscriptions) processDomainPermission(
Obfuscate: wantedPerm.GetObfuscate(),
SubscriptionID: permSub.ID,
}
- insertF = func() error { return s.state.DB.CreateDomainBlock(ctx, domainBlock) }
+ insertF = func() error { return s.state.DB.PutDomainBlock(ctx, domainBlock) }
action = &gtsmodel.AdminAction{
ID: id.NewULID(),
@@ -461,7 +461,7 @@ func (s *Subscriptions) processDomainPermission(
Obfuscate: wantedPerm.GetObfuscate(),
SubscriptionID: permSub.ID,
}
- insertF = func() error { return s.state.DB.CreateDomainAllow(ctx, domainAllow) }
+ insertF = func() error { return s.state.DB.PutDomainAllow(ctx, domainAllow) }
action = &gtsmodel.AdminAction{
ID: id.NewULID(),
@@ -564,13 +564,13 @@ func permsFromCSV(
for i, columnHeader := range columnHeaders {
// Remove leading # if present.
- normal := strings.TrimLeft(columnHeader, "#")
+ columnHeader = strings.TrimLeft(columnHeader, "#")
// Find index of each column header we
// care about, ensuring no duplicates.
- switch normal {
+ switch {
- case "domain":
+ case columnHeader == "domain":
if domainI != nil {
body.Close()
err := gtserror.NewfAt(3, "duplicate domain column header in csv: %+v", columnHeaders)
@@ -578,7 +578,7 @@ func permsFromCSV(
}
domainI = &i
- case "severity":
+ case columnHeader == "severity":
if severityI != nil {
body.Close()
err := gtserror.NewfAt(3, "duplicate severity column header in csv: %+v", columnHeaders)
@@ -586,15 +586,15 @@ func permsFromCSV(
}
severityI = &i
- case "public_comment":
+ case columnHeader == "public_comment" || columnHeader == "comment":
if publicCommentI != nil {
body.Close()
- err := gtserror.NewfAt(3, "duplicate public_comment column header in csv: %+v", columnHeaders)
+ err := gtserror.NewfAt(3, "duplicate public_comment or comment column header in csv: %+v", columnHeaders)
return nil, err
}
publicCommentI = &i
- case "obfuscate":
+ case columnHeader == "obfuscate":
if obfuscateI != nil {
body.Close()
err := gtserror.NewfAt(3, "duplicate obfuscate column header in csv: %+v", columnHeaders)
@@ -674,15 +674,15 @@ func permsFromCSV(
perm.SetPublicComment(record[*publicCommentI])
}
+ var obfuscate bool
if obfuscateI != nil {
- obfuscate, err := strconv.ParseBool(record[*obfuscateI])
+ obfuscate, err = strconv.ParseBool(record[*obfuscateI])
if err != nil {
l.Warnf("couldn't parse obfuscate field of record: %+v", record)
continue
}
-
- perm.SetObfuscate(&obfuscate)
}
+ perm.SetObfuscate(&obfuscate)
// We're done.
perms = append(perms, perm)
@@ -742,8 +742,9 @@ func permsFromJSON(
}
// Set remaining fields.
- perm.SetPublicComment(apiPerm.PublicComment)
- perm.SetObfuscate(&apiPerm.Obfuscate)
+ publicComment := cmp.Or(apiPerm.PublicComment, apiPerm.Comment)
+ perm.SetPublicComment(util.PtrOrZero(publicComment))
+ perm.SetObfuscate(util.Ptr(util.PtrOrZero(apiPerm.Obfuscate)))
// We're done.
perms = append(perms, perm)
@@ -792,9 +793,15 @@ func permsFromPlain(
var perm gtsmodel.DomainPermission
switch permType {
case gtsmodel.DomainPermissionBlock:
- perm = &gtsmodel.DomainBlock{Domain: domain}
+ perm = &gtsmodel.DomainBlock{
+ Domain: domain,
+ Obfuscate: util.Ptr(false),
+ }
case gtsmodel.DomainPermissionAllow:
- perm = &gtsmodel.DomainAllow{Domain: domain}
+ perm = &gtsmodel.DomainAllow{
+ Domain: domain,
+ Obfuscate: util.Ptr(false),
+ }
}
// We're done.
diff --git a/internal/subscriptions/subscriptions_test.go b/internal/subscriptions/subscriptions_test.go
index 133db4b7c..4441d8c15 100644
--- a/internal/subscriptions/subscriptions_test.go
+++ b/internal/subscriptions/subscriptions_test.go
@@ -775,7 +775,7 @@ func (suite *SubscriptionsTestSuite) TestAdoption() {
existingBlock2,
existingBlock3,
} {
- if err := testStructs.State.DB.CreateDomainBlock(
+ if err := testStructs.State.DB.PutDomainBlock(
ctx, block,
); err != nil {
suite.FailNow(err.Error())
@@ -876,7 +876,7 @@ func (suite *SubscriptionsTestSuite) TestDomainAllowsAndBlocks() {
}
// Store existing allow.
- if err := testStructs.State.DB.CreateDomainAllow(ctx, existingAllow); err != nil {
+ if err := testStructs.State.DB.PutDomainAllow(ctx, existingAllow); err != nil {
suite.FailNow(err.Error())
}