diff --git a/internal/config/ssh.go b/internal/config/ssh.go index 6a4896c..eb8cbb9 100644 --- a/internal/config/ssh.go +++ b/internal/config/ssh.go @@ -96,25 +96,45 @@ func ParseSSHConfig() ([]SSHHost, error) { // ParseSSHConfigFile parses a specific SSH config file and returns the list of hosts func ParseSSHConfigFile(configPath string) ([]SSHHost, error) { + return parseSSHConfigFileWithProcessedFiles(configPath, make(map[string]bool)) +} + +// parseSSHConfigFileWithProcessedFiles parses SSH config with include support +func parseSSHConfigFileWithProcessedFiles(configPath string, processedFiles map[string]bool) ([]SSHHost, error) { + // Resolve absolute path to prevent infinite recursion + absPath, err := filepath.Abs(configPath) + if err != nil { + return nil, fmt.Errorf("failed to resolve absolute path for %s: %w", configPath, err) + } + + // Check for circular includes + if processedFiles[absPath] { + return []SSHHost{}, nil // Skip already processed files silently + } + processedFiles[absPath] = true + // Check if the file exists, otherwise create it (and the parent directory if needed) if _, err := os.Stat(configPath); os.IsNotExist(err) { - // Ensure .ssh directory exists with proper permissions - if err := ensureSSHDirectory(); err != nil { - return nil, fmt.Errorf("failed to create .ssh directory: %w", err) + // Only create the main config file, not included files + if absPath == getMainConfigPath() { + // Ensure .ssh directory exists with proper permissions + if err := ensureSSHDirectory(); err != nil { + return nil, fmt.Errorf("failed to create .ssh directory: %w", err) + } + + file, err := os.OpenFile(configPath, os.O_CREATE|os.O_WRONLY, 0600) + if err != nil { + return nil, fmt.Errorf("failed to create SSH config file: %w", err) + } + file.Close() + + // Set secure permissions on the config file + if err := SetSecureFilePermissions(configPath); err != nil { + return nil, fmt.Errorf("failed to set secure permissions: %w", err) + } } - file, err := os.OpenFile(configPath, os.O_CREATE|os.O_WRONLY, 0600) - if err != nil { - return nil, fmt.Errorf("failed to create SSH config file: %w", err) - } - file.Close() - - // Set secure permissions on the config file - if err := SetSecureFilePermissions(configPath); err != nil { - return nil, fmt.Errorf("failed to set secure permissions: %w", err) - } - - // File created, return empty host list + // File doesn't exist, return empty host list return []SSHHost{}, nil } @@ -168,11 +188,25 @@ func ParseSSHConfigFile(configPath string) ([]SSHHost, error) { value := strings.Join(parts[1:], " ") switch key { + case "include": + // Handle Include directive + includeHosts, err := processIncludeDirective(value, configPath, processedFiles) + if err != nil { + // Don't fail the entire parse if include fails, just skip it + continue + } + hosts = append(hosts, includeHosts...) case "host": // New host, save previous one if it exists if currentHost != nil { hosts = append(hosts, *currentHost) } + // Skip hosts with wildcards (*, ?) as they are typically patterns, not actual hosts + if strings.ContainsAny(value, "*?") { + currentHost = nil + pendingTags = nil + continue + } // Create new host currentHost = &SSHHost{ Name: value, @@ -222,6 +256,55 @@ func ParseSSHConfigFile(configPath string) ([]SSHHost, error) { return hosts, scanner.Err() } +// processIncludeDirective processes an Include directive and returns hosts from included files +func processIncludeDirective(pattern string, baseConfigPath string, processedFiles map[string]bool) ([]SSHHost, error) { + // Expand tilde to home directory + if strings.HasPrefix(pattern, "~") { + homeDir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("failed to get home directory: %w", err) + } + pattern = filepath.Join(homeDir, pattern[1:]) + } + + // If pattern is not absolute, make it relative to the base config directory + if !filepath.IsAbs(pattern) { + baseDir := filepath.Dir(baseConfigPath) + pattern = filepath.Join(baseDir, pattern) + } + + // Use glob to find matching files + matches, err := filepath.Glob(pattern) + if err != nil { + return nil, fmt.Errorf("failed to glob pattern %s: %w", pattern, err) + } + + var allHosts []SSHHost + for _, match := range matches { + // Skip directories + if info, err := os.Stat(match); err == nil && info.IsDir() { + continue + } + + // Recursively parse the included file + hosts, err := parseSSHConfigFileWithProcessedFiles(match, processedFiles) + if err != nil { + // Skip files that can't be parsed rather than failing completely + continue + } + allHosts = append(allHosts, hosts...) + } + + return allHosts, nil +} + +// getMainConfigPath returns the main SSH config path for comparison +func getMainConfigPath() string { + configPath, _ := GetDefaultSSHConfigPath() + absPath, _ := filepath.Abs(configPath) + return absPath +} + // AddSSHHost adds a new SSH host to the config file func AddSSHHost(host SSHHost) error { configPath, err := GetDefaultSSHConfigPath() diff --git a/internal/config/ssh_test.go b/internal/config/ssh_test.go index 1256816..d13128d 100644 --- a/internal/config/ssh_test.go +++ b/internal/config/ssh_test.go @@ -1,6 +1,7 @@ package config import ( + "os" "path/filepath" "runtime" "strings" @@ -71,3 +72,296 @@ func TestEnsureSSHDirectory(t *testing.T) { t.Fatalf("ensureSSHDirectory() error = %v", err) } } + +func TestParseSSHConfigWithInclude(t *testing.T) { + // Create temporary directory for test files + tempDir := t.TempDir() + + // Create main config file + mainConfig := filepath.Join(tempDir, "config") + mainConfigContent := `Host main-host + HostName example.com + User mainuser + +Include included.conf +Include subdir/* + +Host another-host + HostName another.example.com + User anotheruser +` + + err := os.WriteFile(mainConfig, []byte(mainConfigContent), 0600) + if err != nil { + t.Fatalf("Failed to create main config: %v", err) + } + + // Create included file + includedConfig := filepath.Join(tempDir, "included.conf") + includedConfigContent := `Host included-host + HostName included.example.com + User includeduser + Port 2222 +` + + err = os.WriteFile(includedConfig, []byte(includedConfigContent), 0600) + if err != nil { + t.Fatalf("Failed to create included config: %v", err) + } + + // Create subdirectory with another config file + subDir := filepath.Join(tempDir, "subdir") + err = os.MkdirAll(subDir, 0700) + if err != nil { + t.Fatalf("Failed to create subdir: %v", err) + } + + subConfig := filepath.Join(subDir, "sub.conf") + subConfigContent := `Host sub-host + HostName sub.example.com + User subuser + IdentityFile ~/.ssh/sub_key +` + + err = os.WriteFile(subConfig, []byte(subConfigContent), 0600) + if err != nil { + t.Fatalf("Failed to create sub config: %v", err) + } + + // Parse the main config file + hosts, err := ParseSSHConfigFile(mainConfig) + if err != nil { + t.Fatalf("ParseSSHConfigFile() error = %v", err) + } + + // Check that we got all expected hosts + expectedHosts := map[string]struct{}{ + "main-host": {}, + "included-host": {}, + "sub-host": {}, + "another-host": {}, + } + + if len(hosts) != len(expectedHosts) { + t.Errorf("Expected %d hosts, got %d", len(expectedHosts), len(hosts)) + } + + for _, host := range hosts { + if _, exists := expectedHosts[host.Name]; !exists { + t.Errorf("Unexpected host found: %s", host.Name) + } + delete(expectedHosts, host.Name) + + // Validate specific host properties + switch host.Name { + case "main-host": + if host.Hostname != "example.com" || host.User != "mainuser" { + t.Errorf("main-host properties incorrect: hostname=%s, user=%s", host.Hostname, host.User) + } + case "included-host": + if host.Hostname != "included.example.com" || host.User != "includeduser" || host.Port != "2222" { + t.Errorf("included-host properties incorrect: hostname=%s, user=%s, port=%s", host.Hostname, host.User, host.Port) + } + case "sub-host": + if host.Hostname != "sub.example.com" || host.User != "subuser" || host.Identity != "~/.ssh/sub_key" { + t.Errorf("sub-host properties incorrect: hostname=%s, user=%s, identity=%s", host.Hostname, host.User, host.Identity) + } + case "another-host": + if host.Hostname != "another.example.com" || host.User != "anotheruser" { + t.Errorf("another-host properties incorrect: hostname=%s, user=%s", host.Hostname, host.User) + } + } + } + + // Check that all expected hosts were found + if len(expectedHosts) > 0 { + var missing []string + for host := range expectedHosts { + missing = append(missing, host) + } + t.Errorf("Missing hosts: %v", missing) + } +} + +func TestParseSSHConfigWithCircularInclude(t *testing.T) { + // Create temporary directory for test files + tempDir := t.TempDir() + + // Create config1 that includes config2 + config1 := filepath.Join(tempDir, "config1") + config1Content := `Host host1 + HostName example1.com + +Include config2 +` + + err := os.WriteFile(config1, []byte(config1Content), 0600) + if err != nil { + t.Fatalf("Failed to create config1: %v", err) + } + + // Create config2 that includes config1 (circular) + config2 := filepath.Join(tempDir, "config2") + config2Content := `Host host2 + HostName example2.com + +Include config1 +` + + err = os.WriteFile(config2, []byte(config2Content), 0600) + if err != nil { + t.Fatalf("Failed to create config2: %v", err) + } + + // Parse the config file - should not cause infinite recursion + hosts, err := ParseSSHConfigFile(config1) + if err != nil { + t.Fatalf("ParseSSHConfigFile() error = %v", err) + } + + // Should get both hosts exactly once + expectedHosts := map[string]bool{ + "host1": false, + "host2": false, + } + + for _, host := range hosts { + if _, exists := expectedHosts[host.Name]; !exists { + t.Errorf("Unexpected host found: %s", host.Name) + } else { + if expectedHosts[host.Name] { + t.Errorf("Host %s found multiple times", host.Name) + } + expectedHosts[host.Name] = true + } + } + + // Check all hosts were found + for hostName, found := range expectedHosts { + if !found { + t.Errorf("Host %s not found", hostName) + } + } +} + +func TestParseSSHConfigWithNonExistentInclude(t *testing.T) { + // Create temporary directory for test files + tempDir := t.TempDir() + + // Create main config file with non-existent include + mainConfig := filepath.Join(tempDir, "config") + mainConfigContent := `Host main-host + HostName example.com + +Include non-existent-file.conf + +Host another-host + HostName another.example.com +` + + err := os.WriteFile(mainConfig, []byte(mainConfigContent), 0600) + if err != nil { + t.Fatalf("Failed to create main config: %v", err) + } + + // Parse should succeed and ignore the non-existent include + hosts, err := ParseSSHConfigFile(mainConfig) + if err != nil { + t.Fatalf("ParseSSHConfigFile() error = %v", err) + } + + // Should get the hosts that exist, ignoring the failed include + if len(hosts) != 2 { + t.Errorf("Expected 2 hosts, got %d", len(hosts)) + } + + hostNames := make(map[string]bool) + for _, host := range hosts { + hostNames[host.Name] = true + } + + if !hostNames["main-host"] || !hostNames["another-host"] { + t.Errorf("Expected main-host and another-host, got: %v", hostNames) + } +} + +func TestParseSSHConfigWithWildcardHosts(t *testing.T) { + // Create temporary directory for test files + tempDir := t.TempDir() + + // Create config file with wildcard hosts + configFile := filepath.Join(tempDir, "config") + configContent := `# Wildcard patterns should be ignored +Host *.example.com + User defaultuser + IdentityFile ~/.ssh/id_rsa + +Host server-* + Port 2222 + +Host * + ServerAliveInterval 60 + +# Real hosts should be included +Host real-server + HostName real.example.com + User realuser + +Host another-real-server + HostName another.example.com + User anotheruser +` + + err := os.WriteFile(configFile, []byte(configContent), 0600) + if err != nil { + t.Fatalf("Failed to create config: %v", err) + } + + // Parse the config file + hosts, err := ParseSSHConfigFile(configFile) + if err != nil { + t.Fatalf("ParseSSHConfigFile() error = %v", err) + } + + // Should only get real hosts, not wildcard patterns + expectedHosts := map[string]bool{ + "real-server": false, + "another-real-server": false, + } + + if len(hosts) != len(expectedHosts) { + t.Errorf("Expected %d hosts, got %d", len(expectedHosts), len(hosts)) + for _, host := range hosts { + t.Logf("Found host: %s", host.Name) + } + } + + for _, host := range hosts { + if _, expected := expectedHosts[host.Name]; !expected { + t.Errorf("Unexpected host found: %s", host.Name) + } else { + expectedHosts[host.Name] = true + } + } + + // Check that all expected hosts were found + for hostName, found := range expectedHosts { + if !found { + t.Errorf("Expected host %s not found", hostName) + } + } + + // Verify host properties + for _, host := range hosts { + switch host.Name { + case "real-server": + if host.Hostname != "real.example.com" || host.User != "realuser" { + t.Errorf("real-server properties incorrect: hostname=%s, user=%s", host.Hostname, host.User) + } + case "another-real-server": + if host.Hostname != "another.example.com" || host.User != "anotheruser" { + t.Errorf("another-real-server properties incorrect: hostname=%s, user=%s", host.Hostname, host.User) + } + } + } +}