mirror of
https://github.com/Gu1llaum-3/sshm.git
synced 2025-09-07 21:30:39 +02:00
add ssh config include
This commit is contained in:
parent
2ade315ddc
commit
e0dd32993a
@ -96,25 +96,45 @@ func ParseSSHConfig() ([]SSHHost, error) {
|
|||||||
|
|
||||||
// ParseSSHConfigFile parses a specific SSH config file and returns the list of hosts
|
// ParseSSHConfigFile parses a specific SSH config file and returns the list of hosts
|
||||||
func ParseSSHConfigFile(configPath string) ([]SSHHost, error) {
|
func ParseSSHConfigFile(configPath string) ([]SSHHost, error) {
|
||||||
|
return parseSSHConfigFileWithProcessedFiles(configPath, make(map[string]bool))
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseSSHConfigFileWithProcessedFiles parses SSH config with include support
|
||||||
|
func parseSSHConfigFileWithProcessedFiles(configPath string, processedFiles map[string]bool) ([]SSHHost, error) {
|
||||||
|
// Resolve absolute path to prevent infinite recursion
|
||||||
|
absPath, err := filepath.Abs(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to resolve absolute path for %s: %w", configPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for circular includes
|
||||||
|
if processedFiles[absPath] {
|
||||||
|
return []SSHHost{}, nil // Skip already processed files silently
|
||||||
|
}
|
||||||
|
processedFiles[absPath] = true
|
||||||
|
|
||||||
// Check if the file exists, otherwise create it (and the parent directory if needed)
|
// Check if the file exists, otherwise create it (and the parent directory if needed)
|
||||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||||
// Ensure .ssh directory exists with proper permissions
|
// Only create the main config file, not included files
|
||||||
if err := ensureSSHDirectory(); err != nil {
|
if absPath == getMainConfigPath() {
|
||||||
return nil, fmt.Errorf("failed to create .ssh directory: %w", err)
|
// Ensure .ssh directory exists with proper permissions
|
||||||
|
if err := ensureSSHDirectory(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create .ssh directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := os.OpenFile(configPath, os.O_CREATE|os.O_WRONLY, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create SSH config file: %w", err)
|
||||||
|
}
|
||||||
|
file.Close()
|
||||||
|
|
||||||
|
// Set secure permissions on the config file
|
||||||
|
if err := SetSecureFilePermissions(configPath); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to set secure permissions: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
file, err := os.OpenFile(configPath, os.O_CREATE|os.O_WRONLY, 0600)
|
// File doesn't exist, return empty host list
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create SSH config file: %w", err)
|
|
||||||
}
|
|
||||||
file.Close()
|
|
||||||
|
|
||||||
// Set secure permissions on the config file
|
|
||||||
if err := SetSecureFilePermissions(configPath); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to set secure permissions: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// File created, return empty host list
|
|
||||||
return []SSHHost{}, nil
|
return []SSHHost{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -168,11 +188,25 @@ func ParseSSHConfigFile(configPath string) ([]SSHHost, error) {
|
|||||||
value := strings.Join(parts[1:], " ")
|
value := strings.Join(parts[1:], " ")
|
||||||
|
|
||||||
switch key {
|
switch key {
|
||||||
|
case "include":
|
||||||
|
// Handle Include directive
|
||||||
|
includeHosts, err := processIncludeDirective(value, configPath, processedFiles)
|
||||||
|
if err != nil {
|
||||||
|
// Don't fail the entire parse if include fails, just skip it
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
hosts = append(hosts, includeHosts...)
|
||||||
case "host":
|
case "host":
|
||||||
// New host, save previous one if it exists
|
// New host, save previous one if it exists
|
||||||
if currentHost != nil {
|
if currentHost != nil {
|
||||||
hosts = append(hosts, *currentHost)
|
hosts = append(hosts, *currentHost)
|
||||||
}
|
}
|
||||||
|
// Skip hosts with wildcards (*, ?) as they are typically patterns, not actual hosts
|
||||||
|
if strings.ContainsAny(value, "*?") {
|
||||||
|
currentHost = nil
|
||||||
|
pendingTags = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
// Create new host
|
// Create new host
|
||||||
currentHost = &SSHHost{
|
currentHost = &SSHHost{
|
||||||
Name: value,
|
Name: value,
|
||||||
@ -222,6 +256,55 @@ func ParseSSHConfigFile(configPath string) ([]SSHHost, error) {
|
|||||||
return hosts, scanner.Err()
|
return hosts, scanner.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// processIncludeDirective processes an Include directive and returns hosts from included files
|
||||||
|
func processIncludeDirective(pattern string, baseConfigPath string, processedFiles map[string]bool) ([]SSHHost, error) {
|
||||||
|
// Expand tilde to home directory
|
||||||
|
if strings.HasPrefix(pattern, "~") {
|
||||||
|
homeDir, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||||
|
}
|
||||||
|
pattern = filepath.Join(homeDir, pattern[1:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// If pattern is not absolute, make it relative to the base config directory
|
||||||
|
if !filepath.IsAbs(pattern) {
|
||||||
|
baseDir := filepath.Dir(baseConfigPath)
|
||||||
|
pattern = filepath.Join(baseDir, pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use glob to find matching files
|
||||||
|
matches, err := filepath.Glob(pattern)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to glob pattern %s: %w", pattern, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var allHosts []SSHHost
|
||||||
|
for _, match := range matches {
|
||||||
|
// Skip directories
|
||||||
|
if info, err := os.Stat(match); err == nil && info.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively parse the included file
|
||||||
|
hosts, err := parseSSHConfigFileWithProcessedFiles(match, processedFiles)
|
||||||
|
if err != nil {
|
||||||
|
// Skip files that can't be parsed rather than failing completely
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
allHosts = append(allHosts, hosts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return allHosts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getMainConfigPath returns the main SSH config path for comparison
|
||||||
|
func getMainConfigPath() string {
|
||||||
|
configPath, _ := GetDefaultSSHConfigPath()
|
||||||
|
absPath, _ := filepath.Abs(configPath)
|
||||||
|
return absPath
|
||||||
|
}
|
||||||
|
|
||||||
// AddSSHHost adds a new SSH host to the config file
|
// AddSSHHost adds a new SSH host to the config file
|
||||||
func AddSSHHost(host SSHHost) error {
|
func AddSSHHost(host SSHHost) error {
|
||||||
configPath, err := GetDefaultSSHConfigPath()
|
configPath, err := GetDefaultSSHConfigPath()
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
@ -71,3 +72,296 @@ func TestEnsureSSHDirectory(t *testing.T) {
|
|||||||
t.Fatalf("ensureSSHDirectory() error = %v", err)
|
t.Fatalf("ensureSSHDirectory() error = %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseSSHConfigWithInclude(t *testing.T) {
|
||||||
|
// Create temporary directory for test files
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create main config file
|
||||||
|
mainConfig := filepath.Join(tempDir, "config")
|
||||||
|
mainConfigContent := `Host main-host
|
||||||
|
HostName example.com
|
||||||
|
User mainuser
|
||||||
|
|
||||||
|
Include included.conf
|
||||||
|
Include subdir/*
|
||||||
|
|
||||||
|
Host another-host
|
||||||
|
HostName another.example.com
|
||||||
|
User anotheruser
|
||||||
|
`
|
||||||
|
|
||||||
|
err := os.WriteFile(mainConfig, []byte(mainConfigContent), 0600)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create main config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create included file
|
||||||
|
includedConfig := filepath.Join(tempDir, "included.conf")
|
||||||
|
includedConfigContent := `Host included-host
|
||||||
|
HostName included.example.com
|
||||||
|
User includeduser
|
||||||
|
Port 2222
|
||||||
|
`
|
||||||
|
|
||||||
|
err = os.WriteFile(includedConfig, []byte(includedConfigContent), 0600)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create included config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create subdirectory with another config file
|
||||||
|
subDir := filepath.Join(tempDir, "subdir")
|
||||||
|
err = os.MkdirAll(subDir, 0700)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create subdir: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
subConfig := filepath.Join(subDir, "sub.conf")
|
||||||
|
subConfigContent := `Host sub-host
|
||||||
|
HostName sub.example.com
|
||||||
|
User subuser
|
||||||
|
IdentityFile ~/.ssh/sub_key
|
||||||
|
`
|
||||||
|
|
||||||
|
err = os.WriteFile(subConfig, []byte(subConfigContent), 0600)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create sub config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the main config file
|
||||||
|
hosts, err := ParseSSHConfigFile(mainConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseSSHConfigFile() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that we got all expected hosts
|
||||||
|
expectedHosts := map[string]struct{}{
|
||||||
|
"main-host": {},
|
||||||
|
"included-host": {},
|
||||||
|
"sub-host": {},
|
||||||
|
"another-host": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(hosts) != len(expectedHosts) {
|
||||||
|
t.Errorf("Expected %d hosts, got %d", len(expectedHosts), len(hosts))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, host := range hosts {
|
||||||
|
if _, exists := expectedHosts[host.Name]; !exists {
|
||||||
|
t.Errorf("Unexpected host found: %s", host.Name)
|
||||||
|
}
|
||||||
|
delete(expectedHosts, host.Name)
|
||||||
|
|
||||||
|
// Validate specific host properties
|
||||||
|
switch host.Name {
|
||||||
|
case "main-host":
|
||||||
|
if host.Hostname != "example.com" || host.User != "mainuser" {
|
||||||
|
t.Errorf("main-host properties incorrect: hostname=%s, user=%s", host.Hostname, host.User)
|
||||||
|
}
|
||||||
|
case "included-host":
|
||||||
|
if host.Hostname != "included.example.com" || host.User != "includeduser" || host.Port != "2222" {
|
||||||
|
t.Errorf("included-host properties incorrect: hostname=%s, user=%s, port=%s", host.Hostname, host.User, host.Port)
|
||||||
|
}
|
||||||
|
case "sub-host":
|
||||||
|
if host.Hostname != "sub.example.com" || host.User != "subuser" || host.Identity != "~/.ssh/sub_key" {
|
||||||
|
t.Errorf("sub-host properties incorrect: hostname=%s, user=%s, identity=%s", host.Hostname, host.User, host.Identity)
|
||||||
|
}
|
||||||
|
case "another-host":
|
||||||
|
if host.Hostname != "another.example.com" || host.User != "anotheruser" {
|
||||||
|
t.Errorf("another-host properties incorrect: hostname=%s, user=%s", host.Hostname, host.User)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that all expected hosts were found
|
||||||
|
if len(expectedHosts) > 0 {
|
||||||
|
var missing []string
|
||||||
|
for host := range expectedHosts {
|
||||||
|
missing = append(missing, host)
|
||||||
|
}
|
||||||
|
t.Errorf("Missing hosts: %v", missing)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseSSHConfigWithCircularInclude(t *testing.T) {
|
||||||
|
// Create temporary directory for test files
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create config1 that includes config2
|
||||||
|
config1 := filepath.Join(tempDir, "config1")
|
||||||
|
config1Content := `Host host1
|
||||||
|
HostName example1.com
|
||||||
|
|
||||||
|
Include config2
|
||||||
|
`
|
||||||
|
|
||||||
|
err := os.WriteFile(config1, []byte(config1Content), 0600)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create config1: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create config2 that includes config1 (circular)
|
||||||
|
config2 := filepath.Join(tempDir, "config2")
|
||||||
|
config2Content := `Host host2
|
||||||
|
HostName example2.com
|
||||||
|
|
||||||
|
Include config1
|
||||||
|
`
|
||||||
|
|
||||||
|
err = os.WriteFile(config2, []byte(config2Content), 0600)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create config2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the config file - should not cause infinite recursion
|
||||||
|
hosts, err := ParseSSHConfigFile(config1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseSSHConfigFile() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should get both hosts exactly once
|
||||||
|
expectedHosts := map[string]bool{
|
||||||
|
"host1": false,
|
||||||
|
"host2": false,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, host := range hosts {
|
||||||
|
if _, exists := expectedHosts[host.Name]; !exists {
|
||||||
|
t.Errorf("Unexpected host found: %s", host.Name)
|
||||||
|
} else {
|
||||||
|
if expectedHosts[host.Name] {
|
||||||
|
t.Errorf("Host %s found multiple times", host.Name)
|
||||||
|
}
|
||||||
|
expectedHosts[host.Name] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check all hosts were found
|
||||||
|
for hostName, found := range expectedHosts {
|
||||||
|
if !found {
|
||||||
|
t.Errorf("Host %s not found", hostName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseSSHConfigWithNonExistentInclude(t *testing.T) {
|
||||||
|
// Create temporary directory for test files
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create main config file with non-existent include
|
||||||
|
mainConfig := filepath.Join(tempDir, "config")
|
||||||
|
mainConfigContent := `Host main-host
|
||||||
|
HostName example.com
|
||||||
|
|
||||||
|
Include non-existent-file.conf
|
||||||
|
|
||||||
|
Host another-host
|
||||||
|
HostName another.example.com
|
||||||
|
`
|
||||||
|
|
||||||
|
err := os.WriteFile(mainConfig, []byte(mainConfigContent), 0600)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create main config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse should succeed and ignore the non-existent include
|
||||||
|
hosts, err := ParseSSHConfigFile(mainConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseSSHConfigFile() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should get the hosts that exist, ignoring the failed include
|
||||||
|
if len(hosts) != 2 {
|
||||||
|
t.Errorf("Expected 2 hosts, got %d", len(hosts))
|
||||||
|
}
|
||||||
|
|
||||||
|
hostNames := make(map[string]bool)
|
||||||
|
for _, host := range hosts {
|
||||||
|
hostNames[host.Name] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hostNames["main-host"] || !hostNames["another-host"] {
|
||||||
|
t.Errorf("Expected main-host and another-host, got: %v", hostNames)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseSSHConfigWithWildcardHosts(t *testing.T) {
|
||||||
|
// Create temporary directory for test files
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create config file with wildcard hosts
|
||||||
|
configFile := filepath.Join(tempDir, "config")
|
||||||
|
configContent := `# Wildcard patterns should be ignored
|
||||||
|
Host *.example.com
|
||||||
|
User defaultuser
|
||||||
|
IdentityFile ~/.ssh/id_rsa
|
||||||
|
|
||||||
|
Host server-*
|
||||||
|
Port 2222
|
||||||
|
|
||||||
|
Host *
|
||||||
|
ServerAliveInterval 60
|
||||||
|
|
||||||
|
# Real hosts should be included
|
||||||
|
Host real-server
|
||||||
|
HostName real.example.com
|
||||||
|
User realuser
|
||||||
|
|
||||||
|
Host another-real-server
|
||||||
|
HostName another.example.com
|
||||||
|
User anotheruser
|
||||||
|
`
|
||||||
|
|
||||||
|
err := os.WriteFile(configFile, []byte(configContent), 0600)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the config file
|
||||||
|
hosts, err := ParseSSHConfigFile(configFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseSSHConfigFile() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should only get real hosts, not wildcard patterns
|
||||||
|
expectedHosts := map[string]bool{
|
||||||
|
"real-server": false,
|
||||||
|
"another-real-server": false,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(hosts) != len(expectedHosts) {
|
||||||
|
t.Errorf("Expected %d hosts, got %d", len(expectedHosts), len(hosts))
|
||||||
|
for _, host := range hosts {
|
||||||
|
t.Logf("Found host: %s", host.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, host := range hosts {
|
||||||
|
if _, expected := expectedHosts[host.Name]; !expected {
|
||||||
|
t.Errorf("Unexpected host found: %s", host.Name)
|
||||||
|
} else {
|
||||||
|
expectedHosts[host.Name] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that all expected hosts were found
|
||||||
|
for hostName, found := range expectedHosts {
|
||||||
|
if !found {
|
||||||
|
t.Errorf("Expected host %s not found", hostName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify host properties
|
||||||
|
for _, host := range hosts {
|
||||||
|
switch host.Name {
|
||||||
|
case "real-server":
|
||||||
|
if host.Hostname != "real.example.com" || host.User != "realuser" {
|
||||||
|
t.Errorf("real-server properties incorrect: hostname=%s, user=%s", host.Hostname, host.User)
|
||||||
|
}
|
||||||
|
case "another-real-server":
|
||||||
|
if host.Hostname != "another.example.com" || host.User != "anotheruser" {
|
||||||
|
t.Errorf("another-real-server properties incorrect: hostname=%s, user=%s", host.Hostname, host.User)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user