mirror of
https://github.com/Gu1llaum-3/sshm.git
synced 2026-03-14 03:41:27 +01:00
fix: enable editing and management of hosts from included SSH config files
• Add SourceFile field to SSHHost struct to track config file origins • Implement FindHostInAllConfigs() to locate hosts across all config files • Fix "host not found" errors when editing/deleting hosts from included files • Add GetAllConfigFiles() and GetAllConfigFilesFromBase() for config discovery • Create UpdateSSHHostV2() and DeleteSSHHostV2() for cross-file operations • Display config file source in edit and info forms for better visibility • Add intelligent file selector for host addition when multiple configs exist • Support -c parameter context with proper file resolution • Exclude .backup files from Include directive processing • Maintain backward compatibility with existing SSH config workflows Resolves limitation where hosts from included config files could be viewed but not edited, deleted, or properly managed through the interface.
This commit is contained in:
@@ -13,14 +13,15 @@ import (
|
||||
|
||||
// SSHHost represents an SSH host configuration
|
||||
type SSHHost struct {
|
||||
Name string
|
||||
Hostname string
|
||||
User string
|
||||
Port string
|
||||
Identity string
|
||||
ProxyJump string
|
||||
Options string
|
||||
Tags []string
|
||||
Name string
|
||||
Hostname string
|
||||
User string
|
||||
Port string
|
||||
Identity string
|
||||
ProxyJump string
|
||||
Options string
|
||||
Tags []string
|
||||
SourceFile string // Path to the config file where this host is defined
|
||||
}
|
||||
|
||||
// GetDefaultSSHConfigPath returns the default SSH config path for the current platform
|
||||
@@ -209,9 +210,10 @@ func parseSSHConfigFileWithProcessedFiles(configPath string, processedFiles map[
|
||||
}
|
||||
// Create new host
|
||||
currentHost = &SSHHost{
|
||||
Name: value,
|
||||
Port: "22", // Default port
|
||||
Tags: pendingTags, // Assign pending tags to this host
|
||||
Name: value,
|
||||
Port: "22", // Default port
|
||||
Tags: pendingTags, // Assign pending tags to this host
|
||||
SourceFile: absPath, // Track which file this host comes from
|
||||
}
|
||||
// Clear pending tags for next host
|
||||
pendingTags = nil
|
||||
@@ -286,6 +288,16 @@ func processIncludeDirective(pattern string, baseConfigPath string, processedFil
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip backup files created by sshm (*.backup)
|
||||
if strings.HasSuffix(match, ".backup") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip markdown files (*.md)
|
||||
if strings.HasSuffix(match, ".md") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Recursively parse the included file
|
||||
hosts, err := parseSSHConfigFileWithProcessedFiles(match, processedFiles)
|
||||
if err != nil {
|
||||
@@ -529,11 +541,7 @@ func GetSSHHostFromFile(hostName string, configPath string) (*SSHHost, error) {
|
||||
|
||||
// UpdateSSHHost updates an existing SSH host configuration
|
||||
func UpdateSSHHost(oldName string, newHost SSHHost) error {
|
||||
configPath, err := GetDefaultSSHConfigPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return UpdateSSHHostInFile(oldName, newHost, configPath)
|
||||
return UpdateSSHHostV2(oldName, newHost)
|
||||
}
|
||||
|
||||
// UpdateSSHHostInFile updates an existing SSH host configuration in a specific file
|
||||
@@ -688,11 +696,7 @@ func UpdateSSHHostInFile(oldName string, newHost SSHHost, configPath string) err
|
||||
|
||||
// DeleteSSHHost removes an SSH host configuration from the config file
|
||||
func DeleteSSHHost(hostName string) error {
|
||||
configPath, err := GetDefaultSSHConfigPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return DeleteSSHHostFromFile(hostName, configPath)
|
||||
return DeleteSSHHostV2(hostName)
|
||||
}
|
||||
|
||||
// DeleteSSHHostFromFile deletes an SSH host from a specific config file
|
||||
@@ -776,3 +780,115 @@ func DeleteSSHHostFromFile(hostName, configPath string) error {
|
||||
newContent := strings.Join(newLines, "\n")
|
||||
return os.WriteFile(configPath, []byte(newContent), 0600)
|
||||
}
|
||||
|
||||
// FindHostInAllConfigs finds a host in all configuration files and returns the host with its source file
|
||||
func FindHostInAllConfigs(hostName string) (*SSHHost, error) {
|
||||
hosts, err := ParseSSHConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, host := range hosts {
|
||||
if host.Name == hostName {
|
||||
return &host, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("host '%s' not found in any configuration file", hostName)
|
||||
}
|
||||
|
||||
// GetAllConfigFiles returns all SSH config files (main + included files)
|
||||
func GetAllConfigFiles() ([]string, error) {
|
||||
configPath, err := GetDefaultSSHConfigPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
processedFiles := make(map[string]bool)
|
||||
_, _ = parseSSHConfigFileWithProcessedFiles(configPath, processedFiles)
|
||||
|
||||
files := make([]string, 0, len(processedFiles))
|
||||
for file := range processedFiles {
|
||||
files = append(files, file)
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
// GetAllConfigFilesFromBase returns all SSH config files starting from a specific base config file
|
||||
func GetAllConfigFilesFromBase(baseConfigPath string) ([]string, error) {
|
||||
if baseConfigPath == "" {
|
||||
// Fallback to default behavior
|
||||
return GetAllConfigFiles()
|
||||
}
|
||||
|
||||
processedFiles := make(map[string]bool)
|
||||
_, _ = parseSSHConfigFileWithProcessedFiles(baseConfigPath, processedFiles)
|
||||
|
||||
files := make([]string, 0, len(processedFiles))
|
||||
for file := range processedFiles {
|
||||
files = append(files, file)
|
||||
}
|
||||
|
||||
return files, nil
|
||||
} // UpdateSSHHostV2 updates an existing SSH host configuration, searching in all config files
|
||||
func UpdateSSHHostV2(oldName string, newHost SSHHost) error {
|
||||
// Find the host to determine which file it's in
|
||||
existingHost, err := FindHostInAllConfigs(oldName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the host in its source file
|
||||
newHost.SourceFile = existingHost.SourceFile
|
||||
return UpdateSSHHostInFile(oldName, newHost, existingHost.SourceFile)
|
||||
}
|
||||
|
||||
// DeleteSSHHostV2 removes an SSH host configuration, searching in all config files
|
||||
func DeleteSSHHostV2(hostName string) error {
|
||||
// Find the host to determine which file it's in
|
||||
existingHost, err := FindHostInAllConfigs(hostName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete the host from its source file
|
||||
return DeleteSSHHostFromFile(hostName, existingHost.SourceFile)
|
||||
}
|
||||
|
||||
// AddSSHHostWithFileSelection adds a new SSH host to a user-specified config file
|
||||
func AddSSHHostWithFileSelection(host SSHHost, targetFile string) error {
|
||||
if targetFile == "" {
|
||||
// Use default file if none specified
|
||||
return AddSSHHost(host)
|
||||
}
|
||||
return AddSSHHostToFile(host, targetFile)
|
||||
}
|
||||
|
||||
// GetIncludedConfigFiles returns a list of config files that can be used for adding hosts
|
||||
func GetIncludedConfigFiles() ([]string, error) {
|
||||
allFiles, err := GetAllConfigFiles()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Filter out files that don't exist or can't be written to
|
||||
var writableFiles []string
|
||||
mainConfig, err := GetDefaultSSHConfigPath()
|
||||
if err == nil {
|
||||
writableFiles = append(writableFiles, mainConfig)
|
||||
}
|
||||
|
||||
for _, file := range allFiles {
|
||||
if file == mainConfig {
|
||||
continue // Already added
|
||||
}
|
||||
|
||||
// Check if file exists and is writable
|
||||
if info, err := os.Stat(file); err == nil && !info.IsDir() {
|
||||
writableFiles = append(writableFiles, file)
|
||||
}
|
||||
}
|
||||
|
||||
return writableFiles, nil
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ func TestEnsureSSHDirectory(t *testing.T) {
|
||||
func TestParseSSHConfigWithInclude(t *testing.T) {
|
||||
// Create temporary directory for test files
|
||||
tempDir := t.TempDir()
|
||||
|
||||
|
||||
// Create main config file
|
||||
mainConfig := filepath.Join(tempDir, "config")
|
||||
mainConfigContent := `Host main-host
|
||||
@@ -90,7 +90,7 @@ Host another-host
|
||||
HostName another.example.com
|
||||
User anotheruser
|
||||
`
|
||||
|
||||
|
||||
err := os.WriteFile(mainConfig, []byte(mainConfigContent), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create main config: %v", err)
|
||||
@@ -103,7 +103,7 @@ Host another-host
|
||||
User includeduser
|
||||
Port 2222
|
||||
`
|
||||
|
||||
|
||||
err = os.WriteFile(includedConfig, []byte(includedConfigContent), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create included config: %v", err)
|
||||
@@ -122,7 +122,7 @@ Host another-host
|
||||
User subuser
|
||||
IdentityFile ~/.ssh/sub_key
|
||||
`
|
||||
|
||||
|
||||
err = os.WriteFile(subConfig, []byte(subConfigContent), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create sub config: %v", err)
|
||||
@@ -158,18 +158,30 @@ Host another-host
|
||||
if host.Hostname != "example.com" || host.User != "mainuser" {
|
||||
t.Errorf("main-host properties incorrect: hostname=%s, user=%s", host.Hostname, host.User)
|
||||
}
|
||||
if host.SourceFile != mainConfig {
|
||||
t.Errorf("main-host SourceFile incorrect: expected=%s, got=%s", mainConfig, host.SourceFile)
|
||||
}
|
||||
case "included-host":
|
||||
if host.Hostname != "included.example.com" || host.User != "includeduser" || host.Port != "2222" {
|
||||
t.Errorf("included-host properties incorrect: hostname=%s, user=%s, port=%s", host.Hostname, host.User, host.Port)
|
||||
}
|
||||
if host.SourceFile != includedConfig {
|
||||
t.Errorf("included-host SourceFile incorrect: expected=%s, got=%s", includedConfig, host.SourceFile)
|
||||
}
|
||||
case "sub-host":
|
||||
if host.Hostname != "sub.example.com" || host.User != "subuser" || host.Identity != "~/.ssh/sub_key" {
|
||||
t.Errorf("sub-host properties incorrect: hostname=%s, user=%s, identity=%s", host.Hostname, host.User, host.Identity)
|
||||
}
|
||||
if host.SourceFile != subConfig {
|
||||
t.Errorf("sub-host SourceFile incorrect: expected=%s, got=%s", subConfig, host.SourceFile)
|
||||
}
|
||||
case "another-host":
|
||||
if host.Hostname != "another.example.com" || host.User != "anotheruser" {
|
||||
t.Errorf("another-host properties incorrect: hostname=%s, user=%s", host.Hostname, host.User)
|
||||
}
|
||||
if host.SourceFile != mainConfig {
|
||||
t.Errorf("another-host SourceFile incorrect: expected=%s, got=%s", mainConfig, host.SourceFile)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -186,7 +198,7 @@ Host another-host
|
||||
func TestParseSSHConfigWithCircularInclude(t *testing.T) {
|
||||
// Create temporary directory for test files
|
||||
tempDir := t.TempDir()
|
||||
|
||||
|
||||
// Create config1 that includes config2
|
||||
config1 := filepath.Join(tempDir, "config1")
|
||||
config1Content := `Host host1
|
||||
@@ -194,7 +206,7 @@ func TestParseSSHConfigWithCircularInclude(t *testing.T) {
|
||||
|
||||
Include config2
|
||||
`
|
||||
|
||||
|
||||
err := os.WriteFile(config1, []byte(config1Content), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create config1: %v", err)
|
||||
@@ -207,7 +219,7 @@ Include config2
|
||||
|
||||
Include config1
|
||||
`
|
||||
|
||||
|
||||
err = os.WriteFile(config2, []byte(config2Content), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create config2: %v", err)
|
||||
@@ -247,7 +259,7 @@ Include config1
|
||||
func TestParseSSHConfigWithNonExistentInclude(t *testing.T) {
|
||||
// Create temporary directory for test files
|
||||
tempDir := t.TempDir()
|
||||
|
||||
|
||||
// Create main config file with non-existent include
|
||||
mainConfig := filepath.Join(tempDir, "config")
|
||||
mainConfigContent := `Host main-host
|
||||
@@ -258,7 +270,7 @@ Include non-existent-file.conf
|
||||
Host another-host
|
||||
HostName another.example.com
|
||||
`
|
||||
|
||||
|
||||
err := os.WriteFile(mainConfig, []byte(mainConfigContent), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create main config: %v", err)
|
||||
@@ -288,7 +300,7 @@ Host another-host
|
||||
func TestParseSSHConfigWithWildcardHosts(t *testing.T) {
|
||||
// Create temporary directory for test files
|
||||
tempDir := t.TempDir()
|
||||
|
||||
|
||||
// Create config file with wildcard hosts
|
||||
configFile := filepath.Join(tempDir, "config")
|
||||
configContent := `# Wildcard patterns should be ignored
|
||||
@@ -311,7 +323,7 @@ Host another-real-server
|
||||
HostName another.example.com
|
||||
User anotheruser
|
||||
`
|
||||
|
||||
|
||||
err := os.WriteFile(configFile, []byte(configContent), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create config: %v", err)
|
||||
@@ -365,3 +377,323 @@ Host another-real-server
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSHConfigExcludesBackupFiles(t *testing.T) {
|
||||
// Create temporary directory for test files
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create main config file with include pattern
|
||||
mainConfig := filepath.Join(tempDir, "config")
|
||||
mainConfigContent := `Host main-host
|
||||
HostName example.com
|
||||
|
||||
Include *.conf
|
||||
`
|
||||
|
||||
err := os.WriteFile(mainConfig, []byte(mainConfigContent), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create main config: %v", err)
|
||||
}
|
||||
|
||||
// Create a regular config file
|
||||
regularConfig := filepath.Join(tempDir, "regular.conf")
|
||||
regularConfigContent := `Host regular-host
|
||||
HostName regular.example.com
|
||||
`
|
||||
|
||||
err = os.WriteFile(regularConfig, []byte(regularConfigContent), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create regular config: %v", err)
|
||||
}
|
||||
|
||||
// Create a backup file that should be excluded
|
||||
backupConfig := filepath.Join(tempDir, "regular.conf.backup")
|
||||
backupConfigContent := `Host backup-host
|
||||
HostName backup.example.com
|
||||
`
|
||||
|
||||
err = os.WriteFile(backupConfig, []byte(backupConfigContent), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create backup config: %v", err)
|
||||
}
|
||||
|
||||
// Parse the config file
|
||||
hosts, err := ParseSSHConfigFile(mainConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseSSHConfigFile() error = %v", err)
|
||||
}
|
||||
|
||||
// Should only get main-host and regular-host, not backup-host
|
||||
expectedHosts := map[string]bool{
|
||||
"main-host": false,
|
||||
"regular-host": false,
|
||||
}
|
||||
|
||||
if len(hosts) != len(expectedHosts) {
|
||||
t.Errorf("Expected %d hosts, got %d", len(expectedHosts), len(hosts))
|
||||
for _, host := range hosts {
|
||||
t.Logf("Found host: %s", host.Name)
|
||||
}
|
||||
}
|
||||
|
||||
for _, host := range hosts {
|
||||
if _, expected := expectedHosts[host.Name]; !expected {
|
||||
t.Errorf("Unexpected host found: %s (backup files should be excluded)", host.Name)
|
||||
} else {
|
||||
expectedHosts[host.Name] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Check that backup-host was not included
|
||||
for _, host := range hosts {
|
||||
if host.Name == "backup-host" {
|
||||
t.Error("backup-host should not be included (backup files should be excluded)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindHostInAllConfigs(t *testing.T) {
|
||||
// Create temporary directory for test files
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create main config file
|
||||
mainConfig := filepath.Join(tempDir, "config")
|
||||
mainConfigContent := `Host main-host
|
||||
HostName example.com
|
||||
|
||||
Include included.conf
|
||||
`
|
||||
|
||||
err := os.WriteFile(mainConfig, []byte(mainConfigContent), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create main config: %v", err)
|
||||
}
|
||||
|
||||
// Create included config file
|
||||
includedConfig := filepath.Join(tempDir, "included.conf")
|
||||
includedConfigContent := `Host included-host
|
||||
HostName included.example.com
|
||||
User includeduser
|
||||
`
|
||||
|
||||
err = os.WriteFile(includedConfig, []byte(includedConfigContent), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create included config: %v", err)
|
||||
}
|
||||
|
||||
// Test finding host from main config
|
||||
host, err := GetSSHHostFromFile("main-host", mainConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSSHHostFromFile() error = %v", err)
|
||||
}
|
||||
if host.Name != "main-host" || host.Hostname != "example.com" {
|
||||
t.Errorf("main-host not found correctly: name=%s, hostname=%s", host.Name, host.Hostname)
|
||||
}
|
||||
if host.SourceFile != mainConfig {
|
||||
t.Errorf("main-host SourceFile incorrect: expected=%s, got=%s", mainConfig, host.SourceFile)
|
||||
}
|
||||
|
||||
// Test finding host from included config
|
||||
// Note: This tests the full parsing with includes
|
||||
hosts, err := ParseSSHConfigFile(mainConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseSSHConfigFile() error = %v", err)
|
||||
}
|
||||
|
||||
var includedHost *SSHHost
|
||||
for _, h := range hosts {
|
||||
if h.Name == "included-host" {
|
||||
includedHost = &h
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if includedHost == nil {
|
||||
t.Fatal("included-host not found")
|
||||
}
|
||||
if includedHost.Hostname != "included.example.com" || includedHost.User != "includeduser" {
|
||||
t.Errorf("included-host properties incorrect: hostname=%s, user=%s", includedHost.Hostname, includedHost.User)
|
||||
}
|
||||
if includedHost.SourceFile != includedConfig {
|
||||
t.Errorf("included-host SourceFile incorrect: expected=%s, got=%s", includedConfig, includedHost.SourceFile)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllConfigFiles(t *testing.T) {
|
||||
// Create temporary directory for test files
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create main config file
|
||||
mainConfig := filepath.Join(tempDir, "config")
|
||||
mainConfigContent := `Host main-host
|
||||
HostName example.com
|
||||
|
||||
Include included.conf
|
||||
Include subdir/*.conf
|
||||
`
|
||||
|
||||
err := os.WriteFile(mainConfig, []byte(mainConfigContent), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create main config: %v", err)
|
||||
}
|
||||
|
||||
// Create included config file
|
||||
includedConfig := filepath.Join(tempDir, "included.conf")
|
||||
err = os.WriteFile(includedConfig, []byte("Host included-host\n HostName included.example.com\n"), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create included config: %v", err)
|
||||
}
|
||||
|
||||
// Create subdirectory with config files
|
||||
subDir := filepath.Join(tempDir, "subdir")
|
||||
err = os.MkdirAll(subDir, 0700)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subdir: %v", err)
|
||||
}
|
||||
|
||||
subConfig := filepath.Join(subDir, "sub.conf")
|
||||
err = os.WriteFile(subConfig, []byte("Host sub-host\n HostName sub.example.com\n"), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create sub config: %v", err)
|
||||
}
|
||||
|
||||
// Parse to populate the processed files map
|
||||
_, err = ParseSSHConfigFile(mainConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseSSHConfigFile() error = %v", err)
|
||||
}
|
||||
|
||||
// Note: GetAllConfigFiles() uses a fresh parse, so we test it indirectly
|
||||
// by checking that all files are found during parsing
|
||||
hosts, err := ParseSSHConfigFile(mainConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseSSHConfigFile() error = %v", err)
|
||||
}
|
||||
|
||||
// Check that hosts from all files are found
|
||||
sourceFiles := make(map[string]bool)
|
||||
for _, host := range hosts {
|
||||
sourceFiles[host.SourceFile] = true
|
||||
}
|
||||
|
||||
expectedFiles := []string{mainConfig, includedConfig, subConfig}
|
||||
for _, expectedFile := range expectedFiles {
|
||||
if !sourceFiles[expectedFile] {
|
||||
t.Errorf("Expected config file not found in SourceFile: %s", expectedFile)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllConfigFilesFromBase(t *testing.T) {
|
||||
// Create temporary directory for test files
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create main config file
|
||||
mainConfig := filepath.Join(tempDir, "config")
|
||||
mainConfigContent := `Host main-host
|
||||
HostName example.com
|
||||
|
||||
Include included.conf
|
||||
`
|
||||
|
||||
err := os.WriteFile(mainConfig, []byte(mainConfigContent), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create main config: %v", err)
|
||||
}
|
||||
|
||||
// Create included config file
|
||||
includedConfig := filepath.Join(tempDir, "included.conf")
|
||||
includedConfigContent := `Host included-host
|
||||
HostName included.example.com
|
||||
|
||||
Include subdir/*.conf
|
||||
`
|
||||
|
||||
err = os.WriteFile(includedConfig, []byte(includedConfigContent), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create included config: %v", err)
|
||||
}
|
||||
|
||||
// Create subdirectory with config files
|
||||
subDir := filepath.Join(tempDir, "subdir")
|
||||
err = os.MkdirAll(subDir, 0700)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subdir: %v", err)
|
||||
}
|
||||
|
||||
subConfig := filepath.Join(subDir, "sub.conf")
|
||||
err = os.WriteFile(subConfig, []byte("Host sub-host\n HostName sub.example.com\n"), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create sub config: %v", err)
|
||||
}
|
||||
|
||||
// Create an isolated config file that should not be included
|
||||
isolatedConfig := filepath.Join(tempDir, "isolated.conf")
|
||||
err = os.WriteFile(isolatedConfig, []byte("Host isolated-host\n HostName isolated.example.com\n"), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create isolated config: %v", err)
|
||||
}
|
||||
|
||||
// Test GetAllConfigFilesFromBase with main config as base
|
||||
files, err := GetAllConfigFilesFromBase(mainConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAllConfigFilesFromBase() error = %v", err)
|
||||
}
|
||||
|
||||
// Should find main config, included config, and sub config, but not isolated config
|
||||
expectedFiles := map[string]bool{
|
||||
mainConfig: false,
|
||||
includedConfig: false,
|
||||
subConfig: false,
|
||||
}
|
||||
|
||||
if len(files) != len(expectedFiles) {
|
||||
t.Errorf("Expected %d config files, got %d", len(expectedFiles), len(files))
|
||||
for i, file := range files {
|
||||
t.Logf("Found file %d: %s", i+1, file)
|
||||
}
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
if _, expected := expectedFiles[file]; expected {
|
||||
expectedFiles[file] = true
|
||||
} else if file == isolatedConfig {
|
||||
t.Errorf("Isolated config file should not be included: %s", file)
|
||||
} else {
|
||||
t.Logf("Unexpected file found: %s", file)
|
||||
}
|
||||
}
|
||||
|
||||
// Check that all expected files were found
|
||||
for file, found := range expectedFiles {
|
||||
if !found {
|
||||
t.Errorf("Expected config file not found: %s", file)
|
||||
}
|
||||
}
|
||||
|
||||
// Test GetAllConfigFilesFromBase with isolated config as base (should only return itself)
|
||||
isolatedFiles, err := GetAllConfigFilesFromBase(isolatedConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAllConfigFilesFromBase() error = %v", err)
|
||||
}
|
||||
|
||||
if len(isolatedFiles) != 1 || isolatedFiles[0] != isolatedConfig {
|
||||
t.Errorf("Expected only isolated config file, got: %v", isolatedFiles)
|
||||
}
|
||||
|
||||
// Test with empty base config file path (should fallback to default behavior)
|
||||
defaultFiles, err := GetAllConfigFilesFromBase("")
|
||||
if err != nil {
|
||||
t.Fatalf("GetAllConfigFilesFromBase('') error = %v", err)
|
||||
}
|
||||
|
||||
// Should behave like GetAllConfigFiles()
|
||||
allFiles, err := GetAllConfigFiles()
|
||||
if err != nil {
|
||||
t.Fatalf("GetAllConfigFiles() error = %v", err)
|
||||
}
|
||||
|
||||
if len(defaultFiles) != len(allFiles) {
|
||||
t.Errorf("GetAllConfigFilesFromBase('') should behave like GetAllConfigFiles(). Got %d vs %d files", len(defaultFiles), len(allFiles))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user