mirror of
https://github.com/Gu1llaum-3/sshm.git
synced 2025-10-19 01:17:20 +02:00
feat: filter non-SSH files from config parsing
- Skip README, .git, and documentation files during SSH config parsing - Add QuickHostExists for fast host validation without full parsing - Prevent errors when Include * encounters non-config files
This commit is contained in:
parent
42e87b6827
commit
6ba82b1c97
19
cmd/root.go
19
cmd/root.go
@ -103,27 +103,18 @@ func runInteractiveMode() {
|
||||
}
|
||||
|
||||
func connectToHost(hostName string) {
|
||||
// Parse SSH configurations to verify host exists
|
||||
var hosts []config.SSHHost
|
||||
// Quick check if host exists without full parsing (optimized for connection)
|
||||
var hostFound bool
|
||||
var err error
|
||||
|
||||
if configFile != "" {
|
||||
hosts, err = config.ParseSSHConfigFile(configFile)
|
||||
hostFound, err = config.QuickHostExistsInFile(hostName, configFile)
|
||||
} else {
|
||||
hosts, err = config.ParseSSHConfig()
|
||||
hostFound, err = config.QuickHostExists(hostName)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("Error reading SSH config file: %v", err)
|
||||
}
|
||||
|
||||
// Check if host exists
|
||||
var hostFound bool
|
||||
for _, host := range hosts {
|
||||
if host.Name == hostName {
|
||||
hostFound = true
|
||||
break
|
||||
}
|
||||
log.Fatalf("Error checking SSH config: %v", err)
|
||||
}
|
||||
|
||||
if !hostFound {
|
||||
|
@ -399,6 +399,11 @@ func processIncludeDirective(pattern string, baseConfigPath string, processedFil
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip common non-SSH config file types
|
||||
if isNonSSHConfigFile(match) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Recursively parse the included file
|
||||
hosts, err := parseSSHConfigFileWithProcessedFiles(match, processedFiles)
|
||||
if err != nil {
|
||||
@ -411,6 +416,82 @@ func processIncludeDirective(pattern string, baseConfigPath string, processedFil
|
||||
return allHosts, nil
|
||||
}
|
||||
|
||||
// isNonSSHConfigFile checks if a file should be excluded from SSH config parsing
|
||||
func isNonSSHConfigFile(filePath string) bool {
|
||||
fileName := strings.ToLower(filepath.Base(filePath))
|
||||
|
||||
// Skip common documentation files
|
||||
if fileName == "readme" || fileName == "readme.txt" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Skip files with common non-config extensions
|
||||
excludedExtensions := []string{
|
||||
".txt", ".md", ".rst", ".doc", ".docx", ".pdf",
|
||||
".log", ".tmp", ".bak", ".old", ".orig",
|
||||
".json", ".xml", ".yaml", ".yml", ".toml",
|
||||
".sh", ".bash", ".zsh", ".fish", ".ps1", ".bat", ".cmd",
|
||||
".py", ".pl", ".rb", ".js", ".php", ".go", ".c", ".cpp",
|
||||
".jpg", ".jpeg", ".png", ".gif", ".bmp", ".svg",
|
||||
".zip", ".tar", ".gz", ".bz2", ".xz",
|
||||
}
|
||||
|
||||
for _, ext := range excludedExtensions {
|
||||
if strings.HasSuffix(fileName, ext) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Skip hidden files (starting with .)
|
||||
if strings.HasPrefix(fileName, ".") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Additional check: if file contains common non-SSH content indicators
|
||||
// This is a more expensive check, so we do it last
|
||||
if hasNonSSHContent(filePath) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// hasNonSSHContent performs a quick content check to identify non-SSH files
|
||||
func hasNonSSHContent(filePath string) bool {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return false // If we can't read it, don't exclude it
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Read only the first few KB to check content
|
||||
buffer := make([]byte, 2048)
|
||||
n, err := file.Read(buffer)
|
||||
if err != nil && err != io.EOF {
|
||||
return false
|
||||
}
|
||||
|
||||
content := strings.ToLower(string(buffer[:n]))
|
||||
|
||||
// Check for common non-SSH file indicators
|
||||
nonSSHIndicators := []string{
|
||||
"<!doctype", "<html>", "<xml>", "<?xml",
|
||||
"#!/bin/", "#!/usr/bin/",
|
||||
"# readme", "# documentation", "# license",
|
||||
"package main", "function ", "class ", "def ",
|
||||
"import ", "require ", "#include",
|
||||
"SELECT ", "INSERT ", "UPDATE ", "DELETE ",
|
||||
}
|
||||
|
||||
for _, indicator := range nonSSHIndicators {
|
||||
if strings.Contains(content, indicator) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// getMainConfigPath returns the main SSH config path for comparison
|
||||
func getMainConfigPath() string {
|
||||
configPath, _ := GetDefaultSSHConfigPath()
|
||||
@ -677,6 +758,131 @@ func GetSSHHostFromFile(hostName string, configPath string) (*SSHHost, error) {
|
||||
return nil, fmt.Errorf("host '%s' not found", hostName)
|
||||
}
|
||||
|
||||
// QuickHostExists performs a fast check if a host exists without full parsing
|
||||
// This is optimized for connection scenarios where we just need to verify existence
|
||||
func QuickHostExists(hostName string) (bool, error) {
|
||||
configPath, err := GetDefaultSSHConfigPath()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return QuickHostExistsInFile(hostName, configPath)
|
||||
}
|
||||
|
||||
// QuickHostExistsInFile performs a fast check if a host exists in config files
|
||||
// This stops parsing as soon as the host is found, making it much faster for connection scenarios
|
||||
func QuickHostExistsInFile(hostName string, configPath string) (bool, error) {
|
||||
return quickHostSearchInFile(hostName, configPath, make(map[string]bool))
|
||||
}
|
||||
|
||||
// quickHostSearchInFile performs optimized host search with early termination
|
||||
func quickHostSearchInFile(hostName string, configPath string, processedFiles map[string]bool) (bool, error) {
|
||||
// Resolve absolute path to prevent infinite recursion
|
||||
absPath, err := filepath.Abs(configPath)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to resolve absolute path for %s: %w", configPath, err)
|
||||
}
|
||||
|
||||
// Check for circular includes
|
||||
if processedFiles[absPath] {
|
||||
return false, nil // Skip already processed files silently
|
||||
}
|
||||
processedFiles[absPath] = true
|
||||
|
||||
// Check if the file exists
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
return false, nil // File doesn't exist, host not found
|
||||
}
|
||||
|
||||
file, err := os.Open(configPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
// Ignore empty lines and comments (except includes)
|
||||
if line == "" || (strings.HasPrefix(line, "#") && !strings.HasPrefix(line, "# Tags:")) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Split line into words
|
||||
parts := strings.Fields(line)
|
||||
if len(parts) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := strings.ToLower(parts[0])
|
||||
value := strings.Join(parts[1:], " ")
|
||||
|
||||
switch key {
|
||||
case "include":
|
||||
// Handle Include directive - search in included files
|
||||
if found, err := quickSearchInclude(hostName, value, configPath, processedFiles); err == nil && found {
|
||||
return true, nil // Found in included file
|
||||
}
|
||||
case "host":
|
||||
// Parse multiple host names from the Host line
|
||||
hostNames := strings.Fields(value)
|
||||
|
||||
// Check if our target host is in this Host declaration
|
||||
for _, candidateHostName := range hostNames {
|
||||
// Skip hosts with wildcards (*, ?) as they are typically patterns
|
||||
if !strings.ContainsAny(candidateHostName, "*?") && candidateHostName == hostName {
|
||||
return true, nil // Found the host!
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, scanner.Err()
|
||||
}
|
||||
|
||||
// quickSearchInclude handles Include directives during quick host search
|
||||
func quickSearchInclude(hostName, pattern, baseConfigPath string, processedFiles map[string]bool) (bool, error) {
|
||||
// Expand tilde to home directory
|
||||
if strings.HasPrefix(pattern, "~") {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
pattern = filepath.Join(homeDir, pattern[1:])
|
||||
}
|
||||
|
||||
// If pattern is not absolute, make it relative to the base config directory
|
||||
if !filepath.IsAbs(pattern) {
|
||||
baseDir := filepath.Dir(baseConfigPath)
|
||||
pattern = filepath.Join(baseDir, pattern)
|
||||
}
|
||||
|
||||
// Use glob to find matching files
|
||||
matches, err := filepath.Glob(pattern)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to glob pattern %s: %w", pattern, err)
|
||||
}
|
||||
|
||||
for _, match := range matches {
|
||||
// Skip directories
|
||||
if info, err := os.Stat(match); err == nil && info.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip non-SSH config files (this avoids parsing README, etc.)
|
||||
if isNonSSHConfigFile(match) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Search in the included file
|
||||
if found, err := quickHostSearchInFile(hostName, match, processedFiles); err == nil && found {
|
||||
return true, nil // Found in this included file
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// UpdateSSHHost updates an existing SSH host configuration
|
||||
func UpdateSSHHost(oldName string, newHost SSHHost) error {
|
||||
return UpdateSSHHostV2(oldName, newHost)
|
||||
|
@ -1548,3 +1548,149 @@ func TestAddSSHHostWithSpacesInPath(t *testing.T) {
|
||||
t.Errorf("Expected identity file line with quotes not found.\nContent:\n%s\nExpected line: %s", contentStr, expectedIdentityLine)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNonSSHConfigFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
fileName string
|
||||
expected bool
|
||||
}{
|
||||
// Should be excluded
|
||||
{"README", true},
|
||||
{"README.txt", true},
|
||||
{"README.md", true},
|
||||
{"script.sh", true},
|
||||
{"data.json", true},
|
||||
{"notes.txt", true},
|
||||
{".gitignore", true},
|
||||
{"backup.bak", true},
|
||||
{"old.orig", true},
|
||||
{"log.log", true},
|
||||
{"temp.tmp", true},
|
||||
{"archive.zip", true},
|
||||
{"image.jpg", true},
|
||||
{"python.py", true},
|
||||
{"golang.go", true},
|
||||
{"config.yaml", true},
|
||||
{"config.yml", true},
|
||||
{"config.toml", true},
|
||||
|
||||
// Should NOT be excluded (valid SSH config files)
|
||||
{"config", false},
|
||||
{"servers.conf", false},
|
||||
{"production", false},
|
||||
{"staging", false},
|
||||
{"hosts", false},
|
||||
{"ssh_config", false},
|
||||
{"work-servers", false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
// Create a temporary file for content testing
|
||||
tempDir := t.TempDir()
|
||||
filePath := filepath.Join(tempDir, test.fileName)
|
||||
|
||||
// Write appropriate content based on expected result
|
||||
var content string
|
||||
if test.expected {
|
||||
// Write non-SSH content for files that should be excluded
|
||||
content = "# This is not an SSH config file\nSome random content"
|
||||
} else {
|
||||
// Write SSH-like content for files that should be included
|
||||
content = "Host example\n HostName example.com\n User testuser"
|
||||
}
|
||||
|
||||
err := os.WriteFile(filePath, []byte(content), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test file %s: %v", test.fileName, err)
|
||||
}
|
||||
|
||||
result := isNonSSHConfigFile(filePath)
|
||||
if result != test.expected {
|
||||
t.Errorf("isNonSSHConfigFile(%q) = %v, want %v", test.fileName, result, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuickHostExists(t *testing.T) {
|
||||
// Create temporary directory for test files
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create main config file
|
||||
mainConfig := filepath.Join(tempDir, "config")
|
||||
mainConfigContent := `Host main-host
|
||||
HostName example.com
|
||||
|
||||
Include config.d/*
|
||||
|
||||
Host another-host
|
||||
HostName another.example.com
|
||||
`
|
||||
|
||||
err := os.WriteFile(mainConfig, []byte(mainConfigContent), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create main config: %v", err)
|
||||
}
|
||||
|
||||
// Create config.d directory
|
||||
configDir := filepath.Join(tempDir, "config.d")
|
||||
err = os.MkdirAll(configDir, 0700)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create config.d: %v", err)
|
||||
}
|
||||
|
||||
// Create valid SSH config file in config.d
|
||||
validConfig := filepath.Join(configDir, "servers.conf")
|
||||
validConfigContent := `Host included-host
|
||||
HostName included.example.com
|
||||
User includeduser
|
||||
|
||||
Host production-server
|
||||
HostName prod.example.com
|
||||
User produser
|
||||
`
|
||||
|
||||
err = os.WriteFile(validConfig, []byte(validConfigContent), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create valid config: %v", err)
|
||||
}
|
||||
|
||||
// Create files that should be excluded (README, etc.)
|
||||
excludedFiles := map[string]string{
|
||||
"README": "# This is a README file\nDocumentation goes here",
|
||||
"README.md": "# SSH Configuration\nThis directory contains...",
|
||||
"script.sh": "#!/bin/bash\necho 'hello world'",
|
||||
"data.json": `{"key": "value"}`,
|
||||
}
|
||||
|
||||
for fileName, content := range excludedFiles {
|
||||
filePath := filepath.Join(configDir, fileName)
|
||||
err = os.WriteFile(filePath, []byte(content), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create %s: %v", fileName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test hosts that should be found
|
||||
existingHosts := []string{"main-host", "another-host", "included-host", "production-server"}
|
||||
for _, hostName := range existingHosts {
|
||||
found, err := QuickHostExistsInFile(hostName, mainConfig)
|
||||
if err != nil {
|
||||
t.Errorf("QuickHostExistsInFile(%q) error = %v", hostName, err)
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("QuickHostExistsInFile(%q) = false, want true", hostName)
|
||||
}
|
||||
}
|
||||
|
||||
// Test hosts that should NOT be found
|
||||
nonExistingHosts := []string{"nonexistent-host", "fake-server", "unknown"}
|
||||
for _, hostName := range nonExistingHosts {
|
||||
found, err := QuickHostExistsInFile(hostName, mainConfig)
|
||||
if err != nil {
|
||||
t.Errorf("QuickHostExistsInFile(%q) error = %v", hostName, err)
|
||||
}
|
||||
if found {
|
||||
t.Errorf("QuickHostExistsInFile(%q) = true, want false", hostName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user