mirror of
https://github.com/Gu1llaum-3/sshm.git
synced 2026-01-27 03:04:21 +01:00
feat: add Windows platform support
This commit is contained in:
11
internal/config/permissions_unix.go
Normal file
11
internal/config/permissions_unix.go
Normal file
@@ -0,0 +1,11 @@
|
||||
//go:build !windows
|
||||
|
||||
package config
|
||||
|
||||
import "os"
|
||||
|
||||
// SetSecureFilePermissions configures secure permissions on Unix systems
|
||||
func SetSecureFilePermissions(filepath string) error {
|
||||
// Set file permissions to 0600 (owner read/write only)
|
||||
return os.Chmod(filepath, 0600)
|
||||
}
|
||||
24
internal/config/permissions_windows.go
Normal file
24
internal/config/permissions_windows.go
Normal file
@@ -0,0 +1,24 @@
|
||||
//go:build windows
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
// SetSecureFilePermissions configures secure permissions on Windows
|
||||
func SetSecureFilePermissions(filepath string) error {
|
||||
// On Windows, file permissions work differently
|
||||
// We ensure the file is not read-only and has basic permissions
|
||||
info, err := os.Stat(filepath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure the file is not read-only
|
||||
if info.Mode()&os.ModeType == 0 {
|
||||
return os.Chmod(filepath, 0600)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
@@ -22,6 +23,46 @@ type SSHHost struct {
|
||||
Tags []string
|
||||
}
|
||||
|
||||
// GetDefaultSSHConfigPath returns the default SSH config path for the current platform
|
||||
func GetDefaultSSHConfigPath() (string, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
return filepath.Join(homeDir, ".ssh", "config"), nil
|
||||
default:
|
||||
// Linux, macOS, etc.
|
||||
return filepath.Join(homeDir, ".ssh", "config"), nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetSSHDirectory returns the .ssh directory path
|
||||
func GetSSHDirectory() (string, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(homeDir, ".ssh"), nil
|
||||
}
|
||||
|
||||
// ensureSSHDirectory creates the .ssh directory with appropriate permissions
|
||||
func ensureSSHDirectory() error {
|
||||
sshDir, err := GetSSHDirectory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := os.Stat(sshDir); os.IsNotExist(err) {
|
||||
// 0700 provides owner-only access across platforms
|
||||
return os.MkdirAll(sshDir, 0700)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// configMutex protects SSH config file operations from race conditions
|
||||
var configMutex sync.Mutex
|
||||
|
||||
@@ -46,12 +87,10 @@ func backupConfig(configPath string) error {
|
||||
|
||||
// ParseSSHConfig parses the SSH config file and returns the list of hosts
|
||||
func ParseSSHConfig() ([]SSHHost, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
configPath, err := GetDefaultSSHConfigPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(homeDir, ".ssh", "config")
|
||||
return ParseSSHConfigFile(configPath)
|
||||
}
|
||||
|
||||
@@ -59,18 +98,22 @@ func ParseSSHConfig() ([]SSHHost, error) {
|
||||
func ParseSSHConfigFile(configPath string) ([]SSHHost, error) {
|
||||
// Check if the file exists, otherwise create it (and the parent directory if needed)
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
dir := filepath.Dir(configPath)
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
err = os.MkdirAll(dir, 0700)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create .ssh directory: %w", err)
|
||||
}
|
||||
// Ensure .ssh directory exists with proper permissions
|
||||
if err := ensureSSHDirectory(); err != nil {
|
||||
return nil, fmt.Errorf("failed to create .ssh directory: %w", err)
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(configPath, os.O_CREATE|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create SSH config file: %w", err)
|
||||
}
|
||||
file.Close()
|
||||
|
||||
// Set secure permissions on the config file
|
||||
if err := SetSecureFilePermissions(configPath); err != nil {
|
||||
return nil, fmt.Errorf("failed to set secure permissions: %w", err)
|
||||
}
|
||||
|
||||
// File created, return empty host list
|
||||
return []SSHHost{}, nil
|
||||
}
|
||||
@@ -181,11 +224,10 @@ func ParseSSHConfigFile(configPath string) ([]SSHHost, error) {
|
||||
|
||||
// AddSSHHost adds a new SSH host to the config file
|
||||
func AddSSHHost(host SSHHost) error {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
configPath, err := GetDefaultSSHConfigPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
configPath := filepath.Join(homeDir, ".ssh", "config")
|
||||
return AddSSHHostToFile(host, configPath)
|
||||
}
|
||||
|
||||
@@ -404,11 +446,10 @@ func GetSSHHostFromFile(hostName string, configPath string) (*SSHHost, error) {
|
||||
|
||||
// UpdateSSHHost updates an existing SSH host configuration
|
||||
func UpdateSSHHost(oldName string, newHost SSHHost) error {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
configPath, err := GetDefaultSSHConfigPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
configPath := filepath.Join(homeDir, ".ssh", "config")
|
||||
return UpdateSSHHostInFile(oldName, newHost, configPath)
|
||||
}
|
||||
|
||||
@@ -564,11 +605,10 @@ func UpdateSSHHostInFile(oldName string, newHost SSHHost, configPath string) err
|
||||
|
||||
// DeleteSSHHost removes an SSH host configuration from the config file
|
||||
func DeleteSSHHost(hostName string) error {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
configPath, err := GetDefaultSSHConfigPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
configPath := filepath.Join(homeDir, ".ssh", "config")
|
||||
return DeleteSSHHostFromFile(hostName, configPath)
|
||||
}
|
||||
|
||||
|
||||
73
internal/config/ssh_test.go
Normal file
73
internal/config/ssh_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetDefaultSSHConfigPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
goos string
|
||||
expected string
|
||||
}{
|
||||
{"Linux", "linux", ".ssh/config"},
|
||||
{"macOS", "darwin", ".ssh/config"},
|
||||
{"Windows", "windows", ".ssh/config"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Save original GOOS
|
||||
originalGOOS := runtime.GOOS
|
||||
defer func() {
|
||||
// Note: We can't actually change runtime.GOOS at runtime
|
||||
// This test verifies the function logic with the current OS
|
||||
_ = originalGOOS
|
||||
}()
|
||||
|
||||
configPath, err := GetDefaultSSHConfigPath()
|
||||
if err != nil {
|
||||
t.Fatalf("GetDefaultSSHConfigPath() error = %v", err)
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(configPath, tt.expected) {
|
||||
t.Errorf("Expected path to end with %q, got %q", tt.expected, configPath)
|
||||
}
|
||||
|
||||
// Verify the path uses the correct separator for current OS
|
||||
expectedSeparator := string(filepath.Separator)
|
||||
if !strings.Contains(configPath, expectedSeparator) && len(configPath) > len(tt.expected) {
|
||||
t.Errorf("Path should use OS-specific separator %q, got %q", expectedSeparator, configPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSSHDirectory(t *testing.T) {
|
||||
sshDir, err := GetSSHDirectory()
|
||||
if err != nil {
|
||||
t.Fatalf("GetSSHDirectory() error = %v", err)
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(sshDir, ".ssh") {
|
||||
t.Errorf("Expected directory to end with .ssh, got %q", sshDir)
|
||||
}
|
||||
|
||||
// Verify the path uses the correct separator for current OS
|
||||
expectedSeparator := string(filepath.Separator)
|
||||
if !strings.Contains(sshDir, expectedSeparator) && len(sshDir) > 4 {
|
||||
t.Errorf("Path should use OS-specific separator %q, got %q", expectedSeparator, sshDir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureSSHDirectory(t *testing.T) {
|
||||
// This test just ensures the function doesn't panic
|
||||
// and returns without error when .ssh directory already exists
|
||||
err := ensureSSHDirectory()
|
||||
if err != nil {
|
||||
t.Fatalf("ensureSSHDirectory() error = %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user