mirror of
https://github.com/Gu1llaum-3/sshm.git
synced 2026-01-27 03:04:21 +01:00
add ssh config include
This commit is contained in:
@@ -96,25 +96,45 @@ func ParseSSHConfig() ([]SSHHost, error) {
|
||||
|
||||
// ParseSSHConfigFile parses a specific SSH config file and returns the list of hosts
|
||||
func ParseSSHConfigFile(configPath string) ([]SSHHost, error) {
|
||||
return parseSSHConfigFileWithProcessedFiles(configPath, make(map[string]bool))
|
||||
}
|
||||
|
||||
// parseSSHConfigFileWithProcessedFiles parses SSH config with include support
|
||||
func parseSSHConfigFileWithProcessedFiles(configPath string, processedFiles map[string]bool) ([]SSHHost, error) {
|
||||
// Resolve absolute path to prevent infinite recursion
|
||||
absPath, err := filepath.Abs(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve absolute path for %s: %w", configPath, err)
|
||||
}
|
||||
|
||||
// Check for circular includes
|
||||
if processedFiles[absPath] {
|
||||
return []SSHHost{}, nil // Skip already processed files silently
|
||||
}
|
||||
processedFiles[absPath] = true
|
||||
|
||||
// Check if the file exists, otherwise create it (and the parent directory if needed)
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
// Ensure .ssh directory exists with proper permissions
|
||||
if err := ensureSSHDirectory(); err != nil {
|
||||
return nil, fmt.Errorf("failed to create .ssh directory: %w", err)
|
||||
// Only create the main config file, not included files
|
||||
if absPath == getMainConfigPath() {
|
||||
// Ensure .ssh directory exists with proper permissions
|
||||
if err := ensureSSHDirectory(); err != nil {
|
||||
return nil, fmt.Errorf("failed to create .ssh directory: %w", err)
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(configPath, os.O_CREATE|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create SSH config file: %w", err)
|
||||
}
|
||||
file.Close()
|
||||
|
||||
// Set secure permissions on the config file
|
||||
if err := SetSecureFilePermissions(configPath); err != nil {
|
||||
return nil, fmt.Errorf("failed to set secure permissions: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(configPath, os.O_CREATE|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create SSH config file: %w", err)
|
||||
}
|
||||
file.Close()
|
||||
|
||||
// Set secure permissions on the config file
|
||||
if err := SetSecureFilePermissions(configPath); err != nil {
|
||||
return nil, fmt.Errorf("failed to set secure permissions: %w", err)
|
||||
}
|
||||
|
||||
// File created, return empty host list
|
||||
// File doesn't exist, return empty host list
|
||||
return []SSHHost{}, nil
|
||||
}
|
||||
|
||||
@@ -168,11 +188,25 @@ func ParseSSHConfigFile(configPath string) ([]SSHHost, error) {
|
||||
value := strings.Join(parts[1:], " ")
|
||||
|
||||
switch key {
|
||||
case "include":
|
||||
// Handle Include directive
|
||||
includeHosts, err := processIncludeDirective(value, configPath, processedFiles)
|
||||
if err != nil {
|
||||
// Don't fail the entire parse if include fails, just skip it
|
||||
continue
|
||||
}
|
||||
hosts = append(hosts, includeHosts...)
|
||||
case "host":
|
||||
// New host, save previous one if it exists
|
||||
if currentHost != nil {
|
||||
hosts = append(hosts, *currentHost)
|
||||
}
|
||||
// Skip hosts with wildcards (*, ?) as they are typically patterns, not actual hosts
|
||||
if strings.ContainsAny(value, "*?") {
|
||||
currentHost = nil
|
||||
pendingTags = nil
|
||||
continue
|
||||
}
|
||||
// Create new host
|
||||
currentHost = &SSHHost{
|
||||
Name: value,
|
||||
@@ -222,6 +256,55 @@ func ParseSSHConfigFile(configPath string) ([]SSHHost, error) {
|
||||
return hosts, scanner.Err()
|
||||
}
|
||||
|
||||
// processIncludeDirective processes an Include directive and returns hosts from included files
|
||||
func processIncludeDirective(pattern string, baseConfigPath string, processedFiles map[string]bool) ([]SSHHost, error) {
|
||||
// Expand tilde to home directory
|
||||
if strings.HasPrefix(pattern, "~") {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
pattern = filepath.Join(homeDir, pattern[1:])
|
||||
}
|
||||
|
||||
// If pattern is not absolute, make it relative to the base config directory
|
||||
if !filepath.IsAbs(pattern) {
|
||||
baseDir := filepath.Dir(baseConfigPath)
|
||||
pattern = filepath.Join(baseDir, pattern)
|
||||
}
|
||||
|
||||
// Use glob to find matching files
|
||||
matches, err := filepath.Glob(pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to glob pattern %s: %w", pattern, err)
|
||||
}
|
||||
|
||||
var allHosts []SSHHost
|
||||
for _, match := range matches {
|
||||
// Skip directories
|
||||
if info, err := os.Stat(match); err == nil && info.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Recursively parse the included file
|
||||
hosts, err := parseSSHConfigFileWithProcessedFiles(match, processedFiles)
|
||||
if err != nil {
|
||||
// Skip files that can't be parsed rather than failing completely
|
||||
continue
|
||||
}
|
||||
allHosts = append(allHosts, hosts...)
|
||||
}
|
||||
|
||||
return allHosts, nil
|
||||
}
|
||||
|
||||
// getMainConfigPath returns the main SSH config path for comparison
|
||||
func getMainConfigPath() string {
|
||||
configPath, _ := GetDefaultSSHConfigPath()
|
||||
absPath, _ := filepath.Abs(configPath)
|
||||
return absPath
|
||||
}
|
||||
|
||||
// AddSSHHost adds a new SSH host to the config file
|
||||
func AddSSHHost(host SSHHost) error {
|
||||
configPath, err := GetDefaultSSHConfigPath()
|
||||
|
||||
Reference in New Issue
Block a user