summaryrefslogtreecommitdiff
path: root/vendor/golang.org/x/tools/internal/imports/fix.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/golang.org/x/tools/internal/imports/fix.go')
-rw-r--r--vendor/golang.org/x/tools/internal/imports/fix.go521
1 files changed, 261 insertions, 260 deletions
diff --git a/vendor/golang.org/x/tools/internal/imports/fix.go b/vendor/golang.org/x/tools/internal/imports/fix.go
index 4569313a0..5ae576977 100644
--- a/vendor/golang.org/x/tools/internal/imports/fix.go
+++ b/vendor/golang.org/x/tools/internal/imports/fix.go
@@ -90,18 +90,6 @@ type ImportFix struct {
Relevance float64 // see pkg
}
-// An ImportInfo represents a single import statement.
-type ImportInfo struct {
- ImportPath string // import path, e.g. "crypto/rand".
- Name string // import name, e.g. "crand", or "" if none.
-}
-
-// A packageInfo represents what's known about a package.
-type packageInfo struct {
- name string // real package name, if known.
- exports map[string]bool // known exports.
-}
-
// parseOtherFiles parses all the Go files in srcDir except filename, including
// test files if filename looks like a test.
//
@@ -130,7 +118,7 @@ func parseOtherFiles(ctx context.Context, fset *token.FileSet, srcDir, filename
continue
}
- f, err := parser.ParseFile(fset, filepath.Join(srcDir, fi.Name()), nil, 0)
+ f, err := parser.ParseFile(fset, filepath.Join(srcDir, fi.Name()), nil, parser.SkipObjectResolution)
if err != nil {
continue
}
@@ -161,8 +149,8 @@ func addGlobals(f *ast.File, globals map[string]bool) {
// collectReferences builds a map of selector expressions, from
// left hand side (X) to a set of right hand sides (Sel).
-func collectReferences(f *ast.File) references {
- refs := references{}
+func collectReferences(f *ast.File) References {
+ refs := References{}
var visitor visitFn
visitor = func(node ast.Node) ast.Visitor {
@@ -232,7 +220,7 @@ func (p *pass) findMissingImport(pkg string, syms map[string]bool) *ImportInfo {
allFound := true
for right := range syms {
- if !pkgInfo.exports[right] {
+ if !pkgInfo.Exports[right] {
allFound = false
break
}
@@ -245,11 +233,6 @@ func (p *pass) findMissingImport(pkg string, syms map[string]bool) *ImportInfo {
return nil
}
-// references is set of references found in a Go file. The first map key is the
-// left hand side of a selector expression, the second key is the right hand
-// side, and the value should always be true.
-type references map[string]map[string]bool
-
// A pass contains all the inputs and state necessary to fix a file's imports.
// It can be modified in some ways during use; see comments below.
type pass struct {
@@ -257,27 +240,29 @@ type pass struct {
fset *token.FileSet // fset used to parse f and its siblings.
f *ast.File // the file being fixed.
srcDir string // the directory containing f.
- env *ProcessEnv // the environment to use for go commands, etc.
- loadRealPackageNames bool // if true, load package names from disk rather than guessing them.
- otherFiles []*ast.File // sibling files.
+ logf func(string, ...any)
+ source Source // the environment to use for go commands, etc.
+ loadRealPackageNames bool // if true, load package names from disk rather than guessing them.
+ otherFiles []*ast.File // sibling files.
+ goroot string
// Intermediate state, generated by load.
existingImports map[string][]*ImportInfo
- allRefs references
- missingRefs references
+ allRefs References
+ missingRefs References
// Inputs to fix. These can be augmented between successive fix calls.
lastTry bool // indicates that this is the last call and fix should clean up as best it can.
candidates []*ImportInfo // candidate imports in priority order.
- knownPackages map[string]*packageInfo // information about all known packages.
+ knownPackages map[string]*PackageInfo // information about all known packages.
}
// loadPackageNames saves the package names for everything referenced by imports.
-func (p *pass) loadPackageNames(imports []*ImportInfo) error {
- if p.env.Logf != nil {
- p.env.Logf("loading package names for %v packages", len(imports))
+func (p *pass) loadPackageNames(ctx context.Context, imports []*ImportInfo) error {
+ if p.logf != nil {
+ p.logf("loading package names for %v packages", len(imports))
defer func() {
- p.env.Logf("done loading package names for %v packages", len(imports))
+ p.logf("done loading package names for %v packages", len(imports))
}()
}
var unknown []string
@@ -288,20 +273,17 @@ func (p *pass) loadPackageNames(imports []*ImportInfo) error {
unknown = append(unknown, imp.ImportPath)
}
- resolver, err := p.env.GetResolver()
- if err != nil {
- return err
- }
-
- names, err := resolver.loadPackageNames(unknown, p.srcDir)
+ names, err := p.source.LoadPackageNames(ctx, p.srcDir, unknown)
if err != nil {
return err
}
+ // TODO(rfindley): revisit this. Why do we need to store known packages with
+ // no exports? The inconsistent data is confusing.
for path, name := range names {
- p.knownPackages[path] = &packageInfo{
- name: name,
- exports: map[string]bool{},
+ p.knownPackages[path] = &PackageInfo{
+ Name: name,
+ Exports: map[string]bool{},
}
}
return nil
@@ -329,8 +311,8 @@ func (p *pass) importIdentifier(imp *ImportInfo) string {
return imp.Name
}
known := p.knownPackages[imp.ImportPath]
- if known != nil && known.name != "" {
- return withoutVersion(known.name)
+ if known != nil && known.Name != "" {
+ return withoutVersion(known.Name)
}
return ImportPathToAssumedName(imp.ImportPath)
}
@@ -338,9 +320,9 @@ func (p *pass) importIdentifier(imp *ImportInfo) string {
// load reads in everything necessary to run a pass, and reports whether the
// file already has all the imports it needs. It fills in p.missingRefs with the
// file's missing symbols, if any, or removes unused imports if not.
-func (p *pass) load() ([]*ImportFix, bool) {
- p.knownPackages = map[string]*packageInfo{}
- p.missingRefs = references{}
+func (p *pass) load(ctx context.Context) ([]*ImportFix, bool) {
+ p.knownPackages = map[string]*PackageInfo{}
+ p.missingRefs = References{}
p.existingImports = map[string][]*ImportInfo{}
// Load basic information about the file in question.
@@ -363,10 +345,10 @@ func (p *pass) load() ([]*ImportFix, bool) {
// f's imports by the identifier they introduce.
imports := collectImports(p.f)
if p.loadRealPackageNames {
- err := p.loadPackageNames(append(imports, p.candidates...))
+ err := p.loadPackageNames(ctx, append(imports, p.candidates...))
if err != nil {
- if p.env.Logf != nil {
- p.env.Logf("loading package names: %v", err)
+ if p.logf != nil {
+ p.logf("loading package names: %v", err)
}
return nil, false
}
@@ -536,9 +518,10 @@ func (p *pass) assumeSiblingImportsValid() {
// We have the stdlib in memory; no need to guess.
rights = symbolNameSet(m)
}
- p.addCandidate(imp, &packageInfo{
+ // TODO(rfindley): we should set package name here, for consistency.
+ p.addCandidate(imp, &PackageInfo{
// no name; we already know it.
- exports: rights,
+ Exports: rights,
})
}
}
@@ -547,14 +530,14 @@ func (p *pass) assumeSiblingImportsValid() {
// addCandidate adds a candidate import to p, and merges in the information
// in pkg.
-func (p *pass) addCandidate(imp *ImportInfo, pkg *packageInfo) {
+func (p *pass) addCandidate(imp *ImportInfo, pkg *PackageInfo) {
p.candidates = append(p.candidates, imp)
if existing, ok := p.knownPackages[imp.ImportPath]; ok {
- if existing.name == "" {
- existing.name = pkg.name
+ if existing.Name == "" {
+ existing.Name = pkg.Name
}
- for export := range pkg.exports {
- existing.exports[export] = true
+ for export := range pkg.Exports {
+ existing.Exports[export] = true
}
} else {
p.knownPackages[imp.ImportPath] = pkg
@@ -563,7 +546,14 @@ func (p *pass) addCandidate(imp *ImportInfo, pkg *packageInfo) {
// fixImports adds and removes imports from f so that all its references are
// satisfied and there are no unused imports.
-func fixImports(fset *token.FileSet, f *ast.File, filename string, env *ProcessEnv) error {
+//
+// This is declared as a variable rather than a function so goimports can
+// easily be extended by adding a file with an init function.
+//
+// DO NOT REMOVE: used internally at Google.
+var fixImports = fixImportsDefault
+
+func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string, env *ProcessEnv) error {
fixes, err := getFixes(context.Background(), fset, f, filename, env)
if err != nil {
return err
@@ -575,21 +565,42 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string, env *ProcessE
// getFixes gets the import fixes that need to be made to f in order to fix the imports.
// It does not modify the ast.
func getFixes(ctx context.Context, fset *token.FileSet, f *ast.File, filename string, env *ProcessEnv) ([]*ImportFix, error) {
+ source, err := NewProcessEnvSource(env, filename, f.Name.Name)
+ if err != nil {
+ return nil, err
+ }
+ goEnv, err := env.goEnv()
+ if err != nil {
+ return nil, err
+ }
+ return getFixesWithSource(ctx, fset, f, filename, goEnv["GOROOT"], env.logf, source)
+}
+
+func getFixesWithSource(ctx context.Context, fset *token.FileSet, f *ast.File, filename string, goroot string, logf func(string, ...any), source Source) ([]*ImportFix, error) {
+ // This logic is defensively duplicated from getFixes.
abs, err := filepath.Abs(filename)
if err != nil {
return nil, err
}
srcDir := filepath.Dir(abs)
- if env.Logf != nil {
- env.Logf("fixImports(filename=%q), abs=%q, srcDir=%q ...", filename, abs, srcDir)
+
+ if logf != nil {
+ logf("fixImports(filename=%q), srcDir=%q ...", filename, abs, srcDir)
}
// First pass: looking only at f, and using the naive algorithm to
// derive package names from import paths, see if the file is already
// complete. We can't add any imports yet, because we don't know
// if missing references are actually package vars.
- p := &pass{fset: fset, f: f, srcDir: srcDir, env: env}
- if fixes, done := p.load(); done {
+ p := &pass{
+ fset: fset,
+ f: f,
+ srcDir: srcDir,
+ logf: logf,
+ goroot: goroot,
+ source: source,
+ }
+ if fixes, done := p.load(ctx); done {
return fixes, nil
}
@@ -601,7 +612,7 @@ func getFixes(ctx context.Context, fset *token.FileSet, f *ast.File, filename st
// Second pass: add information from other files in the same package,
// like their package vars and imports.
p.otherFiles = otherFiles
- if fixes, done := p.load(); done {
+ if fixes, done := p.load(ctx); done {
return fixes, nil
}
@@ -614,10 +625,17 @@ func getFixes(ctx context.Context, fset *token.FileSet, f *ast.File, filename st
// Third pass: get real package names where we had previously used
// the naive algorithm.
- p = &pass{fset: fset, f: f, srcDir: srcDir, env: env}
+ p = &pass{
+ fset: fset,
+ f: f,
+ srcDir: srcDir,
+ logf: logf,
+ goroot: goroot,
+ source: p.source, // safe to reuse, as it's just a wrapper around env
+ }
p.loadRealPackageNames = true
p.otherFiles = otherFiles
- if fixes, done := p.load(); done {
+ if fixes, done := p.load(ctx); done {
return fixes, nil
}
@@ -831,7 +849,7 @@ func GetPackageExports(ctx context.Context, wrapped func(PackageExport), searchP
return true
},
dirFound: func(pkg *pkg) bool {
- return pkgIsCandidate(filename, references{searchPkg: nil}, pkg)
+ return pkgIsCandidate(filename, References{searchPkg: nil}, pkg)
},
packageNameLoaded: func(pkg *pkg) bool {
return pkg.packageName == searchPkg
@@ -1014,16 +1032,26 @@ func (e *ProcessEnv) GetResolver() (Resolver, error) {
// already know the view type.
if len(e.Env["GOMOD"]) == 0 && len(e.Env["GOWORK"]) == 0 {
e.resolver = newGopathResolver(e)
+ e.logf("created gopath resolver")
} else if r, err := newModuleResolver(e, e.ModCache); err != nil {
e.resolverErr = err
+ e.logf("failed to create module resolver: %v", err)
} else {
e.resolver = Resolver(r)
+ e.logf("created module resolver")
}
}
return e.resolver, e.resolverErr
}
+// logf logs if e.Logf is non-nil.
+func (e *ProcessEnv) logf(format string, args ...any) {
+ if e.Logf != nil {
+ e.Logf(format, args...)
+ }
+}
+
// buildContext returns the build.Context to use for matching files.
//
// TODO(rfindley): support dynamic GOOS, GOARCH here, when doing cross-platform
@@ -1072,11 +1100,7 @@ func (e *ProcessEnv) invokeGo(ctx context.Context, verb string, args ...string)
return e.GocmdRunner.Run(ctx, inv)
}
-func addStdlibCandidates(pass *pass, refs references) error {
- goenv, err := pass.env.goEnv()
- if err != nil {
- return err
- }
+func addStdlibCandidates(pass *pass, refs References) error {
localbase := func(nm string) string {
ans := path.Base(nm)
if ans[0] == 'v' {
@@ -1091,13 +1115,13 @@ func addStdlibCandidates(pass *pass, refs references) error {
}
add := func(pkg string) {
// Prevent self-imports.
- if path.Base(pkg) == pass.f.Name.Name && filepath.Join(goenv["GOROOT"], "src", pkg) == pass.srcDir {
+ if path.Base(pkg) == pass.f.Name.Name && filepath.Join(pass.goroot, "src", pkg) == pass.srcDir {
return
}
exports := symbolNameSet(stdlib.PackageSymbols[pkg])
pass.addCandidate(
&ImportInfo{ImportPath: pkg},
- &packageInfo{name: localbase(pkg), exports: exports})
+ &PackageInfo{Name: localbase(pkg), Exports: exports})
}
for left := range refs {
if left == "rand" {
@@ -1127,8 +1151,8 @@ type Resolver interface {
// scan works with callback to search for packages. See scanCallback for details.
scan(ctx context.Context, callback *scanCallback) error
- // loadExports returns the set of exported symbols in the package at dir.
- // loadExports may be called concurrently.
+ // loadExports returns the package name and set of exported symbols in the
+ // package at dir. loadExports may be called concurrently.
loadExports(ctx context.Context, pkg *pkg, includeTest bool) (string, []stdlib.Symbol, error)
// scoreImportPath returns the relevance for an import path.
@@ -1161,101 +1185,22 @@ type scanCallback struct {
exportsLoaded func(pkg *pkg, exports []stdlib.Symbol)
}
-func addExternalCandidates(ctx context.Context, pass *pass, refs references, filename string) error {
+func addExternalCandidates(ctx context.Context, pass *pass, refs References, filename string) error {
ctx, done := event.Start(ctx, "imports.addExternalCandidates")
defer done()
- var mu sync.Mutex
- found := make(map[string][]pkgDistance)
- callback := &scanCallback{
- rootFound: func(gopathwalk.Root) bool {
- return true // We want everything.
- },
- dirFound: func(pkg *pkg) bool {
- return pkgIsCandidate(filename, refs, pkg)
- },
- packageNameLoaded: func(pkg *pkg) bool {
- if _, want := refs[pkg.packageName]; !want {
- return false
- }
- if pkg.dir == pass.srcDir && pass.f.Name.Name == pkg.packageName {
- // The candidate is in the same directory and has the
- // same package name. Don't try to import ourselves.
- return false
- }
- if !canUse(filename, pkg.dir) {
- return false
- }
- mu.Lock()
- defer mu.Unlock()
- found[pkg.packageName] = append(found[pkg.packageName], pkgDistance{pkg, distance(pass.srcDir, pkg.dir)})
- return false // We'll do our own loading after we sort.
- },
- }
- resolver, err := pass.env.GetResolver()
+ results, err := pass.source.ResolveReferences(ctx, filename, refs)
if err != nil {
return err
}
- if err = resolver.scan(ctx, callback); err != nil {
- return err
- }
-
- // Search for imports matching potential package references.
- type result struct {
- imp *ImportInfo
- pkg *packageInfo
- }
- results := make(chan result, len(refs))
-
- ctx, cancel := context.WithCancel(ctx)
- var wg sync.WaitGroup
- defer func() {
- cancel()
- wg.Wait()
- }()
- var (
- firstErr error
- firstErrOnce sync.Once
- )
- for pkgName, symbols := range refs {
- wg.Add(1)
- go func(pkgName string, symbols map[string]bool) {
- defer wg.Done()
-
- found, err := findImport(ctx, pass, found[pkgName], pkgName, symbols)
-
- if err != nil {
- firstErrOnce.Do(func() {
- firstErr = err
- cancel()
- })
- return
- }
-
- if found == nil {
- return // No matching package.
- }
-
- imp := &ImportInfo{
- ImportPath: found.importPathShort,
- }
-
- pkg := &packageInfo{
- name: pkgName,
- exports: symbols,
- }
- results <- result{imp, pkg}
- }(pkgName, symbols)
- }
- go func() {
- wg.Wait()
- close(results)
- }()
- for result := range results {
+ for _, result := range results {
+ if result == nil {
+ continue
+ }
// Don't offer completions that would shadow predeclared
// names, such as github.com/coreos/etcd/error.
- if types.Universe.Lookup(result.pkg.name) != nil { // predeclared
+ if types.Universe.Lookup(result.Package.Name) != nil { // predeclared
// Ideally we would skip this candidate only
// if the predeclared name is actually
// referenced by the file, but that's a lot
@@ -1264,9 +1209,9 @@ func addExternalCandidates(ctx context.Context, pass *pass, refs references, fil
// user before long.
continue
}
- pass.addCandidate(result.imp, result.pkg)
+ pass.addCandidate(result.Import, result.Package)
}
- return firstErr
+ return nil
}
// notIdentifier reports whether ch is an invalid identifier character.
@@ -1608,11 +1553,10 @@ func loadExportsFromFiles(ctx context.Context, env *ProcessEnv, dir string, incl
}
fullFile := filepath.Join(dir, fi.Name())
+ // Legacy ast.Object resolution is needed here.
f, err := parser.ParseFile(fset, fullFile, nil, 0)
if err != nil {
- if env.Logf != nil {
- env.Logf("error parsing %v: %v", fullFile, err)
- }
+ env.logf("error parsing %v: %v", fullFile, err)
continue
}
if f.Name.Name == "documentation" {
@@ -1648,9 +1592,7 @@ func loadExportsFromFiles(ctx context.Context, env *ProcessEnv, dir string, incl
}
sortSymbols(exports)
- if env.Logf != nil {
- env.Logf("loaded exports in dir %v (package %v): %v", dir, pkgName, exports)
- }
+ env.logf("loaded exports in dir %v (package %v): %v", dir, pkgName, exports)
return pkgName, exports, nil
}
@@ -1660,25 +1602,39 @@ func sortSymbols(syms []stdlib.Symbol) {
})
}
-// findImport searches for a package with the given symbols.
-// If no package is found, findImport returns ("", false, nil)
-func findImport(ctx context.Context, pass *pass, candidates []pkgDistance, pkgName string, symbols map[string]bool) (*pkg, error) {
+// A symbolSearcher searches for a package with a set of symbols, among a set
+// of candidates. See [symbolSearcher.search].
+//
+// The search occurs within the scope of a single file, with context captured
+// in srcDir and xtest.
+type symbolSearcher struct {
+ logf func(string, ...any)
+ srcDir string // directory containing the file
+ xtest bool // if set, the file containing is an x_test file
+ loadExports func(ctx context.Context, pkg *pkg, includeTest bool) (string, []stdlib.Symbol, error)
+}
+
+// search searches the provided candidates for a package containing all
+// exported symbols.
+//
+// If successful, returns the resulting package.
+func (s *symbolSearcher) search(ctx context.Context, candidates []pkgDistance, pkgName string, symbols map[string]bool) (*pkg, error) {
// Sort the candidates by their import package length,
// assuming that shorter package names are better than long
// ones. Note that this sorts by the de-vendored name, so
// there's no "penalty" for vendoring.
sort.Sort(byDistanceOrImportPathShortLength(candidates))
- if pass.env.Logf != nil {
+ if s.logf != nil {
for i, c := range candidates {
- pass.env.Logf("%s candidate %d/%d: %v in %v", pkgName, i+1, len(candidates), c.pkg.importPathShort, c.pkg.dir)
+ s.logf("%s candidate %d/%d: %v in %v", pkgName, i+1, len(candidates), c.pkg.importPathShort, c.pkg.dir)
}
}
- resolver, err := pass.env.GetResolver()
- if err != nil {
- return nil, err
- }
- // Collect exports for packages with matching names.
+ // Arrange rescv so that we can we can await results in order of relevance
+ // and exit as soon as we find the first match.
+ //
+ // Search with bounded concurrency, returning as soon as the first result
+ // among rescv is non-nil.
rescv := make([]chan *pkg, len(candidates))
for i := range candidates {
rescv[i] = make(chan *pkg, 1)
@@ -1686,6 +1642,7 @@ func findImport(ctx context.Context, pass *pass, candidates []pkgDistance, pkgNa
const maxConcurrentPackageImport = 4
loadExportsSem := make(chan struct{}, maxConcurrentPackageImport)
+ // Ensure that all work is completed at exit.
ctx, cancel := context.WithCancel(ctx)
var wg sync.WaitGroup
defer func() {
@@ -1693,6 +1650,7 @@ func findImport(ctx context.Context, pass *pass, candidates []pkgDistance, pkgNa
wg.Wait()
}()
+ // Start the search.
wg.Add(1)
go func() {
defer wg.Done()
@@ -1703,55 +1661,67 @@ func findImport(ctx context.Context, pass *pass, candidates []pkgDistance, pkgNa
return
}
+ i := i
+ c := c
wg.Add(1)
- go func(c pkgDistance, resc chan<- *pkg) {
+ go func() {
defer func() {
<-loadExportsSem
wg.Done()
}()
-
- if pass.env.Logf != nil {
- pass.env.Logf("loading exports in dir %s (seeking package %s)", c.pkg.dir, pkgName)
+ if s.logf != nil {
+ s.logf("loading exports in dir %s (seeking package %s)", c.pkg.dir, pkgName)
}
- // If we're an x_test, load the package under test's test variant.
- includeTest := strings.HasSuffix(pass.f.Name.Name, "_test") && c.pkg.dir == pass.srcDir
- _, exports, err := resolver.loadExports(ctx, c.pkg, includeTest)
+ pkg, err := s.searchOne(ctx, c, symbols)
if err != nil {
- if pass.env.Logf != nil {
- pass.env.Logf("loading exports in dir %s (seeking package %s): %v", c.pkg.dir, pkgName, err)
- }
- resc <- nil
- return
- }
-
- exportsMap := make(map[string]bool, len(exports))
- for _, sym := range exports {
- exportsMap[sym.Name] = true
- }
-
- // If it doesn't have the right
- // symbols, send nil to mean no match.
- for symbol := range symbols {
- if !exportsMap[symbol] {
- resc <- nil
- return
+ if s.logf != nil && ctx.Err() == nil {
+ s.logf("loading exports in dir %s (seeking package %s): %v", c.pkg.dir, pkgName, err)
}
+ pkg = nil
}
- resc <- c.pkg
- }(c, rescv[i])
+ rescv[i] <- pkg // may be nil
+ }()
}
}()
+ // Await the first (best) result.
for _, resc := range rescv {
- pkg := <-resc
- if pkg == nil {
- continue
+ select {
+ case r := <-resc:
+ if r != nil {
+ return r, nil
+ }
+ case <-ctx.Done():
+ return nil, ctx.Err()
}
- return pkg, nil
}
return nil, nil
}
+func (s *symbolSearcher) searchOne(ctx context.Context, c pkgDistance, symbols map[string]bool) (*pkg, error) {
+ if ctx.Err() != nil {
+ return nil, ctx.Err()
+ }
+ // If we're considering the package under test from an x_test, load the
+ // test variant.
+ includeTest := s.xtest && c.pkg.dir == s.srcDir
+ _, exports, err := s.loadExports(ctx, c.pkg, includeTest)
+ if err != nil {
+ return nil, err
+ }
+
+ exportsMap := make(map[string]bool, len(exports))
+ for _, sym := range exports {
+ exportsMap[sym.Name] = true
+ }
+ for symbol := range symbols {
+ if !exportsMap[symbol] {
+ return nil, nil // no match
+ }
+ }
+ return c.pkg, nil
+}
+
// pkgIsCandidate reports whether pkg is a candidate for satisfying the
// finding which package pkgIdent in the file named by filename is trying
// to refer to.
@@ -1764,65 +1734,31 @@ func findImport(ctx context.Context, pass *pass, candidates []pkgDistance, pkgNa
// filename is the file being formatted.
// pkgIdent is the package being searched for, like "client" (if
// searching for "client.New")
-func pkgIsCandidate(filename string, refs references, pkg *pkg) bool {
+func pkgIsCandidate(filename string, refs References, pkg *pkg) bool {
// Check "internal" and "vendor" visibility:
if !canUse(filename, pkg.dir) {
return false
}
// Speed optimization to minimize disk I/O:
- // the last two components on disk must contain the
- // package name somewhere.
//
- // This permits mismatch naming like directory
- // "go-foo" being package "foo", or "pkg.v3" being "pkg",
- // or directory "google.golang.org/api/cloudbilling/v1"
- // being package "cloudbilling", but doesn't
- // permit a directory "foo" to be package
- // "bar", which is strongly discouraged
- // anyway. There's no reason goimports needs
- // to be slow just to accommodate that.
+ // Use the matchesPath heuristic to filter to package paths that could
+ // reasonably match a dangling reference.
+ //
+ // This permits mismatch naming like directory "go-foo" being package "foo",
+ // or "pkg.v3" being "pkg", or directory
+ // "google.golang.org/api/cloudbilling/v1" being package "cloudbilling", but
+ // doesn't permit a directory "foo" to be package "bar", which is strongly
+ // discouraged anyway. There's no reason goimports needs to be slow just to
+ // accommodate that.
for pkgIdent := range refs {
- lastTwo := lastTwoComponents(pkg.importPathShort)
- if strings.Contains(lastTwo, pkgIdent) {
+ if matchesPath(pkgIdent, pkg.importPathShort) {
return true
}
- if hasHyphenOrUpperASCII(lastTwo) && !hasHyphenOrUpperASCII(pkgIdent) {
- lastTwo = lowerASCIIAndRemoveHyphen(lastTwo)
- if strings.Contains(lastTwo, pkgIdent) {
- return true
- }
- }
}
return false
}
-func hasHyphenOrUpperASCII(s string) bool {
- for i := 0; i < len(s); i++ {
- b := s[i]
- if b == '-' || ('A' <= b && b <= 'Z') {
- return true
- }
- }
- return false
-}
-
-func lowerASCIIAndRemoveHyphen(s string) (ret string) {
- buf := make([]byte, 0, len(s))
- for i := 0; i < len(s); i++ {
- b := s[i]
- switch {
- case b == '-':
- continue
- case 'A' <= b && b <= 'Z':
- buf = append(buf, b+('a'-'A'))
- default:
- buf = append(buf, b)
- }
- }
- return string(buf)
-}
-
// canUse reports whether the package in dir is usable from filename,
// respecting the Go "internal" and "vendor" visibility rules.
func canUse(filename, dir string) bool {
@@ -1863,19 +1799,84 @@ func canUse(filename, dir string) bool {
return !strings.Contains(relSlash, "/vendor/") && !strings.Contains(relSlash, "/internal/") && !strings.HasSuffix(relSlash, "/internal")
}
-// lastTwoComponents returns at most the last two path components
-// of v, using either / or \ as the path separator.
-func lastTwoComponents(v string) string {
+// matchesPath reports whether ident may match a potential package name
+// referred to by path, using heuristics to filter out unidiomatic package
+// names.
+//
+// Specifically, it checks whether either of the last two '/'- or '\'-delimited
+// path segments matches the identifier. The segment-matching heuristic must
+// allow for various conventions around segment naming, including go-foo,
+// foo-go, and foo.v3. To handle all of these, matching considers both (1) the
+// entire segment, ignoring '-' and '.', as well as (2) the last subsegment
+// separated by '-' or '.'. So the segment foo-go matches all of the following
+// identifiers: foo, go, and foogo. All matches are case insensitive (for ASCII
+// identifiers).
+//
+// See the docstring for [pkgIsCandidate] for an explanation of how this
+// heuristic filters potential candidate packages.
+func matchesPath(ident, path string) bool {
+ // Ignore case, for ASCII.
+ lowerIfASCII := func(b byte) byte {
+ if 'A' <= b && b <= 'Z' {
+ return b + ('a' - 'A')
+ }
+ return b
+ }
+
+ // match reports whether path[start:end] matches ident, ignoring [.-].
+ match := func(start, end int) bool {
+ ii := len(ident) - 1 // current byte in ident
+ pi := end - 1 // current byte in path
+ for ; pi >= start && ii >= 0; pi-- {
+ pb := path[pi]
+ if pb == '-' || pb == '.' {
+ continue
+ }
+ pb = lowerIfASCII(pb)
+ ib := lowerIfASCII(ident[ii])
+ if pb != ib {
+ return false
+ }
+ ii--
+ }
+ return ii < 0 && pi < start // all bytes matched
+ }
+
+ // segmentEnd and subsegmentEnd hold the end points of the current segment
+ // and subsegment intervals.
+ segmentEnd := len(path)
+ subsegmentEnd := len(path)
+
+ // Count slashes; we only care about the last two segments.
nslash := 0
- for i := len(v) - 1; i >= 0; i-- {
- if v[i] == '/' || v[i] == '\\' {
+
+ for i := len(path) - 1; i >= 0; i-- {
+ switch b := path[i]; b {
+ // TODO(rfindley): we handle backlashes here only because the previous
+ // heuristic handled backslashes. This is perhaps overly defensive, but is
+ // the result of many lessons regarding Chesterton's fence and the
+ // goimports codebase.
+ //
+ // However, this function is only ever called with something called an
+ // 'importPath'. Is it possible that this is a real import path, and
+ // therefore we need only consider forward slashes?
+ case '/', '\\':
+ if match(i+1, segmentEnd) || match(i+1, subsegmentEnd) {
+ return true
+ }
nslash++
if nslash == 2 {
- return v[i:]
+ return false // did not match above
+ }
+ segmentEnd, subsegmentEnd = i, i // reset
+ case '-', '.':
+ if match(i+1, subsegmentEnd) {
+ return true
}
+ subsegmentEnd = i
}
}
- return v
+ return match(0, segmentEnd) || match(0, subsegmentEnd)
}
type visitFn func(node ast.Node) ast.Visitor