7 Commits

Author SHA1 Message Date
838941e3eb fix: allow env vars and SSH tokens in IdentityFile validation (issue #33)
ValidateIdentityFile now accepts $VAR/${VAR} (expanded via os.Expand, undefined vars accepted as-is) and SSH tokens like %d, %h before falling back to os.Stat.
The raw value is preserved when writing to ssh_config.
2026-02-23 23:04:40 +01:00
2a1f6d5449 fix: replace sshm process with ssh via syscall.Exec (issue #41)
When running `sshm <host>`, the sshm process was staying alive as a parent for the entire SSH session.
History is recorded before SSH starts, so the parent process served no purpose.

Use syscall.Exec() to replace the sshm process in-place with ssh, keeping the same PID. Falls back to exec.Command() on Windows where syscall.Exec is not supported.
2026-02-23 21:48:41 +01:00
f189cb37e3 feat: add --no-update-check flag and disable update check via config (issue #23)
Add support for disabling the automatic update check at startup, which could cause delays on air-gapped or offline machines due to DNS timeouts.

- Add --no-update-check CLI flag for one-time override
- Add check_for_updates field (*bool) to AppConfig with default true
- CLI flag overrides the config file setting (both feed into IsUpdateCheckEnabled)
- Move update check from --version template to TUI Init() only, respecting the new configuration
- Remove getVersionWithUpdateCheck() from cmd/root.go; --version now prints a plain version string
- Rename internal/config/keybindings.go → appconfig.go and keybindings_test.go → appconfig_test.go to reflect the broader scope of the file
- Add TestIsUpdateCheckEnabled with table-driven cases (nil config, nil field, true, false) and extend existing integration test with a CheckForUpdates round-trip
- Update README: document --no-update-check flag, config option, and rename "Custom Key Bindings" section to "Application Configuration"
2026-02-23 21:28:54 +01:00
Guillaume Archambault
891fb2a0f4 Merge pull request #44 from fgbm/main
Fix: connectivity check for hosts using ProxyJump or ProxyCommand
2026-02-22 12:24:42 +01:00
Guillaume Archambault
473b1b6063 Merge pull request #40 from boxpositron/feat/info-command
feat: add info command for JSON host details
2026-02-22 12:19:38 +01:00
Vladislav Chmelyuk
5d0c0ffcf3 refactor: update NewPingManager to accept a config file parameter
- Modified the NewPingManager function to include a configFile argument for better SSH configuration management.
- Updated all relevant tests to reflect the new function signature.
- Enhanced ping functionality to support ProxyJump and ProxyCommand using an external SSH command.
- Adjusted UI initialization to pass the config file to the PingManager.

This change improves flexibility in managing SSH connections and enhances the overall functionality of the ping manager.
2026-02-04 14:17:28 +03:00
David Ibia
7d9b794ceb feat: add info command for JSON host details
Adds a jq-friendly `sshm info` subcommand with host completion and documentation, and makes home directory resolution testable for backup path tests.
2026-01-12 23:53:35 +01:00
15 changed files with 820 additions and 79 deletions

View File

@@ -268,13 +268,40 @@ sshm move my-server -c /path/to/custom/ssh_config
# Search for hosts (interactive filter) # Search for hosts (interactive filter)
sshm search sshm search
# Show version information (includes update check) # Print machine-readable info (JSON) for scripting
sshm info prod-server
sshm info prod-server --pretty
# With a custom SSH config file
sshm -c /path/to/custom/ssh_config info prod-server
# Pipe to jq
sshm info prod-server | jq -r '.result.target.hostname'
sshm info prod-server | jq -r '.result.target.user'
# Show version information
sshm --version sshm --version
# Disable automatic update check (useful on air-gapped machines)
sshm --no-update-check
# Show help and available commands # Show help and available commands
sshm --help sshm --help
``` ```
### Host Info (JSON)
`sshm info <hostname>` prints a single JSON object to stdout so you can script against it with `jq`.
```bash
# Extract fields
sshm info prod-server | jq -r '.result.target.hostname'
sshm info prod-server | jq -r '.result.target.port'
# Check not-found (exit code 2)
sshm info does-not-exist | jq -r '.error.code'
```
### Shell Completion ### Shell Completion
SSHM supports shell completion for host names, making it easy to connect to hosts without typing full names: SSHM supports shell completion for host names, making it easy to connect to hosts without typing full names:
@@ -475,17 +502,31 @@ SSHM features asynchronous SSH connectivity checking that provides visual indica
SSHM includes built-in version checking that notifies you of available updates: SSHM includes built-in version checking that notifies you of available updates:
**Features:** **Features:**
- **Background checking** - Version check happens asynchronously - **Background checking** - Version check happens asynchronously, never blocking startup
- **Release notifications** - Clear indicators when updates are available - **Release notifications** - Clear indicators when updates are available
- **Pre-release detection** - Identifies beta and development versions - **Pre-release detection** - Identifies beta and development versions
- **GitHub integration** - Direct links to release pages - **GitHub integration** - Direct links to release pages
- **Non-intrusive** - Updates don't interrupt your workflow - **Non-intrusive** - Updates don't interrupt your workflow
- **Configurable** - Can be disabled for air-gapped or offline environments
**Update notifications appear:** **Update notifications appear:**
- In the main TUI interface as a subtle notification - In the main TUI interface as a subtle notification
- In the `sshm --version` command output
- Only when a newer stable version is available - Only when a newer stable version is available
**Disabling update checks:**
Via the CLI flag (one-time):
```bash
sshm --no-update-check
```
Via `~/.config/sshm/config.json` (persistent):
```json
{
"check_for_updates": false
}
```
#### Port Forwarding History #### Port Forwarding History
SSHM remembers your port forwarding configurations for easy reuse: SSHM remembers your port forwarding configurations for easy reuse:
@@ -639,9 +680,9 @@ This will be automatically converted to:
StrictHostKeyChecking no StrictHostKeyChecking no
``` ```
### Custom Key Bindings ### Application Configuration
SSHM supports customizable key bindings through a configuration file. This is particularly useful for users who want to modify the default quit behavior. SSHM supports a configuration file to customize its behavior, including key bindings and update checking.
**Configuration File Location:** **Configuration File Location:**
- **Linux/macOS**: `~/.config/sshm/config.json` - **Linux/macOS**: `~/.config/sshm/config.json`
@@ -650,6 +691,7 @@ SSHM supports customizable key bindings through a configuration file. This is pa
**Example Configuration:** **Example Configuration:**
```json ```json
{ {
"check_for_updates": false,
"key_bindings": { "key_bindings": {
"quit_keys": ["q", "ctrl+c"], "quit_keys": ["q", "ctrl+c"],
"disable_esc_quit": true "disable_esc_quit": true
@@ -658,12 +700,16 @@ SSHM supports customizable key bindings through a configuration file. This is pa
``` ```
**Available Options:** **Available Options:**
- **check_for_updates**: Boolean to enable or disable the automatic update check at startup. Default: `true`. Set to `false` on air-gapped or offline machines to avoid connection delays.
- **quit_keys**: Array of keys that will quit the application. Default: `["q", "ctrl+c"]` - **quit_keys**: Array of keys that will quit the application. Default: `["q", "ctrl+c"]`
- **disable_esc_quit**: Boolean flag to disable ESC key from quitting the application. Default: `false` - **disable_esc_quit**: Boolean flag to disable ESC key from quitting the application. Default: `false`
**For Vim Users:** **For Vim Users:**
If you frequently press ESC accidentally causing the application to quit, set `disable_esc_quit` to `true`. This will disable ESC as a quit key while preserving all other functionality. If you frequently press ESC accidentally causing the application to quit, set `disable_esc_quit` to `true`. This will disable ESC as a quit key while preserving all other functionality.
**For Air-gapped Machines:**
If SSHM is slow to start due to DNS timeouts when reaching GitHub, set `check_for_updates` to `false`. You can also use the `--no-update-check` CLI flag for a one-time override without editing the config file.
**Default Configuration:** **Default Configuration:**
If no configuration file exists, SSHM will automatically create one with default settings that maintain backward compatibility. If no configuration file exists, SSHM will automatically create one with default settings that maintain backward compatibility.

199
cmd/info.go Normal file
View File

@@ -0,0 +1,199 @@
package cmd
import (
"encoding/json"
"io"
"os"
"strconv"
"strings"
"github.com/Gu1llaum-3/sshm/internal/config"
"github.com/spf13/cobra"
)
type infoResponse struct {
Schema string `json:"schema"`
OK bool `json:"ok"`
Hostname string `json:"hostname"`
Result *infoResult `json:"result"`
Error *infoError `json:"error"`
}
type infoResult struct {
CanonicalName string `json:"canonical_name"`
Target infoTarget `json:"target"`
IdentityFile *string `json:"identity_file"`
ProxyJump *string `json:"proxy_jump"`
ProxyCommand *string `json:"proxy_command"`
Options *string `json:"options"`
Tags []string `json:"tags"`
RemoteCommand *string `json:"remote_command"`
RequestTTY *string `json:"request_tty"`
Source *infoSource `json:"source"`
}
type infoTarget struct {
Host string `json:"host"`
Hostname *string `json:"hostname"`
User *string `json:"user"`
Port *int `json:"port"`
}
type infoSource struct {
File string `json:"file"`
Line int `json:"line"`
}
type infoError struct {
Code string `json:"code"`
Message string `json:"message"`
Details json.RawMessage `json:"details"`
}
func maybeString(v string) *string {
trimmed := strings.TrimSpace(v)
if trimmed == "" {
return nil
}
return &trimmed
}
func maybePort(v string) (*int, error) {
trimmed := strings.TrimSpace(v)
if trimmed == "" {
return nil, nil
}
port, err := strconv.Atoi(trimmed)
if err != nil {
return nil, err
}
return &port, nil
}
func writeInfoJSON(out io.Writer, pretty bool, resp infoResponse) {
var b []byte
var err error
if pretty {
b, err = json.MarshalIndent(resp, "", " ")
} else {
b, err = json.Marshal(resp)
}
if err != nil {
_, _ = io.WriteString(out, `{"schema":"sshm.info.v1","ok":false,"hostname":"","result":null,"error":{"code":"INTERNAL","message":"failed to marshal JSON","details":null}}\n`)
return
}
_, _ = out.Write(append(b, '\n'))
}
func runInfo(out io.Writer, hostnameArg string, cfgFile string, pretty bool) int {
resp := infoResponse{
Schema: "sshm.info.v1",
OK: false,
Hostname: hostnameArg,
Result: nil,
Error: nil,
}
var host *config.SSHHost
var err error
if cfgFile != "" {
host, err = config.GetSSHHostFromFile(hostnameArg, cfgFile)
} else {
host, err = config.GetSSHHost(hostnameArg)
}
if err != nil {
code := 1
errCode := "CONFIG_ERROR"
msg := err.Error()
if strings.Contains(msg, "not found") {
code = 2
errCode = "NOT_FOUND"
}
resp.Error = &infoError{Code: errCode, Message: msg, Details: nil}
writeInfoJSON(out, pretty, resp)
return code
}
port, portErr := maybePort(host.Port)
if portErr != nil {
resp.Error = &infoError{Code: "CONFIG_ERROR", Message: "invalid port in host configuration", Details: nil}
writeInfoJSON(out, pretty, resp)
return 1
}
res := infoResult{
CanonicalName: host.Name,
Target: infoTarget{
Host: hostnameArg,
Hostname: maybeString(host.Hostname),
User: maybeString(host.User),
Port: port,
},
IdentityFile: maybeString(host.Identity),
ProxyJump: maybeString(host.ProxyJump),
ProxyCommand: maybeString(host.ProxyCommand),
Options: maybeString(host.Options),
Tags: host.Tags,
RemoteCommand: maybeString(host.RemoteCommand),
RequestTTY: maybeString(host.RequestTTY),
Source: &infoSource{
File: host.SourceFile,
Line: host.LineNumber,
},
}
resp.OK = true
resp.Result = &res
writeInfoJSON(out, pretty, resp)
return 0
}
var infoPretty bool
var infoCmd = &cobra.Command{
Use: "info <hostname>",
Short: "Print machine-readable information about a host",
Long: "Print machine-readable information (JSON) about a configured SSH host.",
Args: cobra.ExactArgs(1),
SilenceUsage: true,
SilenceErrors: true,
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
if len(args) != 0 {
return nil, cobra.ShellCompDirectiveNoFileComp
}
var hosts []config.SSHHost
var err error
if configFile != "" {
hosts, err = config.ParseSSHConfigFile(configFile)
} else {
hosts, err = config.ParseSSHConfig()
}
if err != nil {
return nil, cobra.ShellCompDirectiveError
}
var completions []string
toCompleteLower := strings.ToLower(toComplete)
for _, host := range hosts {
if strings.HasPrefix(strings.ToLower(host.Name), toCompleteLower) {
completions = append(completions, host.Name)
}
}
return completions, cobra.ShellCompDirectiveNoFileComp
},
RunE: func(cmd *cobra.Command, args []string) error {
exitCode := runInfo(cmd.OutOrStdout(), args[0], configFile, infoPretty)
if exitCode != 0 {
os.Exit(exitCode)
}
return nil
},
}
func init() {
infoCmd.Flags().BoolVar(&infoPretty, "pretty", false, "Pretty-print JSON output")
RootCmd.AddCommand(infoCmd)
}

321
cmd/info_test.go Normal file
View File

@@ -0,0 +1,321 @@
package cmd
import (
"bytes"
"encoding/json"
"os"
"path/filepath"
"strings"
"testing"
"github.com/spf13/cobra"
)
type infoResponseForTest struct {
Schema string `json:"schema"`
OK bool `json:"ok"`
Hostname string `json:"hostname"`
Result *infoResultForTest `json:"result"`
Error *infoErrorForTest `json:"error"`
}
type infoResultForTest struct {
CanonicalName string `json:"canonical_name"`
Target infoTargetForTest `json:"target"`
IdentityFile *string `json:"identity_file"`
ProxyJump *string `json:"proxy_jump"`
ProxyCommand *string `json:"proxy_command"`
Options *string `json:"options"`
Tags []string `json:"tags"`
RemoteCommand *string `json:"remote_command"`
RequestTTY *string `json:"request_tty"`
Source *infoSourceForTest `json:"source"`
}
type infoTargetForTest struct {
Host string `json:"host"`
Hostname *string `json:"hostname"`
User *string `json:"user"`
Port *int `json:"port"`
}
type infoSourceForTest struct {
File string `json:"file"`
Line int `json:"line"`
}
type infoErrorForTest struct {
Code string `json:"code"`
Message string `json:"message"`
Details json.RawMessage `json:"details"`
}
func TestInfoCommandConfig(t *testing.T) {
if infoCmd.Use != "info <hostname>" {
t.Fatalf("infoCmd.Use=%q", infoCmd.Use)
}
err := infoCmd.Args(infoCmd, []string{})
if err == nil {
t.Fatalf("expected args error for no args")
}
err = infoCmd.Args(infoCmd, []string{"one", "two"})
if err == nil {
t.Fatalf("expected args error for too many args")
}
err = infoCmd.Args(infoCmd, []string{"host"})
if err != nil {
t.Fatalf("expected no args error, got %v", err)
}
}
func TestInfoCommandRegistration(t *testing.T) {
found := false
for _, c := range RootCmd.Commands() {
if c.Name() == "info" {
found = true
break
}
}
if !found {
t.Fatalf("info command not registered")
}
}
func TestRunInfoSuccessJSON(t *testing.T) {
tempDir := t.TempDir()
cfg := filepath.Join(tempDir, "config")
cfgContent := `# Tags: prod, web
Host prod-web
HostName 10.0.0.10
User deploy
Port 2222
IdentityFile ~/.ssh/id_prod
ProxyJump bastion
ServerAliveInterval 60
`
if err := os.WriteFile(cfg, []byte(cfgContent), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
buf := new(bytes.Buffer)
exitCode := runInfo(buf, "prod-web", cfg, false)
if exitCode != 0 {
t.Fatalf("exitCode=%d", exitCode)
}
out := buf.String()
if strings.TrimSpace(out) == "" {
t.Fatalf("expected output")
}
var resp infoResponseForTest
if err := json.Unmarshal([]byte(out), &resp); err != nil {
t.Fatalf("output not JSON: %v\noutput=%q", err, out)
}
if resp.Schema != "sshm.info.v1" {
t.Fatalf("schema=%q", resp.Schema)
}
if !resp.OK {
t.Fatalf("ok=false")
}
if resp.Result == nil {
t.Fatalf("result is nil")
}
if resp.Error != nil {
t.Fatalf("error is non-nil")
}
if resp.Result.CanonicalName != "prod-web" {
t.Fatalf("canonical_name=%q", resp.Result.CanonicalName)
}
if resp.Result.Target.Host != "prod-web" {
t.Fatalf("target.host=%q", resp.Result.Target.Host)
}
if resp.Result.Target.Hostname == nil || *resp.Result.Target.Hostname != "10.0.0.10" {
t.Fatalf("target.hostname=%v", resp.Result.Target.Hostname)
}
if resp.Result.Target.User == nil || *resp.Result.Target.User != "deploy" {
t.Fatalf("target.user=%v", resp.Result.Target.User)
}
if resp.Result.Target.Port == nil || *resp.Result.Target.Port != 2222 {
t.Fatalf("target.port=%v", resp.Result.Target.Port)
}
if resp.Result.Source == nil || resp.Result.Source.File == "" || resp.Result.Source.Line == 0 {
t.Fatalf("source missing: %#v", resp.Result.Source)
}
if resp.Result.IdentityFile == nil || *resp.Result.IdentityFile != "~/.ssh/id_prod" {
t.Fatalf("identity_file=%v", resp.Result.IdentityFile)
}
if resp.Result.ProxyJump == nil || *resp.Result.ProxyJump != "bastion" {
t.Fatalf("proxy_jump=%v", resp.Result.ProxyJump)
}
}
func TestRunInfoNotFoundJSON(t *testing.T) {
tempDir := t.TempDir()
cfg := filepath.Join(tempDir, "config")
cfgContent := `Host known
HostName example.com
`
if err := os.WriteFile(cfg, []byte(cfgContent), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
buf := new(bytes.Buffer)
exitCode := runInfo(buf, "missing", cfg, false)
if exitCode != 2 {
t.Fatalf("exitCode=%d", exitCode)
}
var resp infoResponseForTest
if err := json.Unmarshal(buf.Bytes(), &resp); err != nil {
t.Fatalf("output not JSON: %v", err)
}
if resp.OK {
t.Fatalf("ok=true")
}
if resp.Error == nil {
t.Fatalf("error is nil")
}
if resp.Error.Code != "NOT_FOUND" {
t.Fatalf("error.code=%q", resp.Error.Code)
}
}
func TestRunInfoPrettyJSON(t *testing.T) {
tempDir := t.TempDir()
cfg := filepath.Join(tempDir, "config")
cfgContent := `Host known
HostName 127.0.0.1
`
if err := os.WriteFile(cfg, []byte(cfgContent), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
buf := new(bytes.Buffer)
exitCode := runInfo(buf, "known", cfg, true)
if exitCode != 0 {
t.Fatalf("exitCode=%d", exitCode)
}
out := buf.String()
if !strings.Contains(out, "\n") {
t.Fatalf("expected pretty output")
}
var resp infoResponseForTest
if err := json.Unmarshal(buf.Bytes(), &resp); err != nil {
t.Fatalf("output not JSON: %v", err)
}
if !resp.OK {
t.Fatalf("ok=false")
}
}
func TestInfoValidArgsFunction(t *testing.T) {
if infoCmd.ValidArgsFunction == nil {
t.Fatalf("expected ValidArgsFunction to be set on infoCmd")
}
}
func TestInfoValidArgsFunctionWithSSHConfig(t *testing.T) {
tmpDir := t.TempDir()
testConfigFile := filepath.Join(tmpDir, "config")
sshConfig := `Host prod-server
HostName 192.168.1.1
User admin
Host dev-server
HostName 192.168.1.2
User developer
Host staging-db
HostName 192.168.1.3
User dbadmin
`
if err := os.WriteFile(testConfigFile, []byte(sshConfig), 0600); err != nil {
t.Fatalf("Failed to write test config: %v", err)
}
originalConfigFile := configFile
defer func() { configFile = originalConfigFile }()
configFile = testConfigFile
tests := []struct {
name string
toComplete string
args []string
wantCount int
wantHosts []string
}{
{
name: "empty prefix returns all hosts",
toComplete: "",
args: []string{},
wantCount: 3,
wantHosts: []string{"prod-server", "dev-server", "staging-db"},
},
{
name: "prefix filters hosts",
toComplete: "prod",
args: []string{},
wantCount: 1,
wantHosts: []string{"prod-server"},
},
{
name: "prefix case insensitive",
toComplete: "DEV",
args: []string{},
wantCount: 1,
wantHosts: []string{"dev-server"},
},
{
name: "no match returns empty",
toComplete: "nonexistent",
args: []string{},
wantCount: 0,
wantHosts: []string{},
},
{
name: "already has host arg returns nothing",
toComplete: "",
args: []string{"existing-host"},
wantCount: 0,
wantHosts: []string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
completions, directive := infoCmd.ValidArgsFunction(infoCmd, tt.args, tt.toComplete)
if len(completions) != tt.wantCount {
t.Fatalf("Expected %d completions, got %d: %v", tt.wantCount, len(completions), completions)
}
if directive != cobra.ShellCompDirectiveNoFileComp {
t.Fatalf("Expected ShellCompDirectiveNoFileComp, got %v", directive)
}
for _, wantHost := range tt.wantHosts {
found := false
for _, comp := range completions {
if comp == wantHost {
found = true
break
}
}
if !found {
t.Fatalf("Expected completion %q not found in %v", wantHost, completions)
}
}
})
}
}

View File

@@ -1,19 +1,16 @@
package cmd package cmd
import ( import (
"context"
"fmt" "fmt"
"log" "log"
"os" "os"
"os/exec" "os/exec"
"strings" "strings"
"syscall" "syscall"
"time"
"github.com/Gu1llaum-3/sshm/internal/config" "github.com/Gu1llaum-3/sshm/internal/config"
"github.com/Gu1llaum-3/sshm/internal/history" "github.com/Gu1llaum-3/sshm/internal/history"
"github.com/Gu1llaum-3/sshm/internal/ui" "github.com/Gu1llaum-3/sshm/internal/ui"
"github.com/Gu1llaum-3/sshm/internal/version"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@@ -30,6 +27,9 @@ var forceTTY bool
// searchMode enables the focus on search mode at startup // searchMode enables the focus on search mode at startup
var searchMode bool var searchMode bool
// noUpdateCheck disables the async update check in the TUI
var noUpdateCheck bool
// RootCmd is the base command when called without any subcommands // RootCmd is the base command when called without any subcommands
var RootCmd = &cobra.Command{ var RootCmd = &cobra.Command{
Use: "sshm [host] [command...]", Use: "sshm [host] [command...]",
@@ -143,7 +143,7 @@ func runInteractiveMode() {
} }
// Run the interactive TUI // Run the interactive TUI
if err := ui.RunInteractiveMode(hosts, configFile, searchMode, AppVersion); err != nil { if err := ui.RunInteractiveMode(hosts, configFile, searchMode, AppVersion, noUpdateCheck); err != nil {
log.Fatalf("Error running interactive mode: %v", err) log.Fatalf("Error running interactive mode: %v", err)
} }
} }
@@ -196,6 +196,15 @@ func connectToHost(hostName string, remoteCommand []string) {
fmt.Printf("Connecting to %s...\n", hostName) fmt.Printf("Connecting to %s...\n", hostName)
} }
sshPath, lookErr := exec.LookPath("ssh")
if lookErr == nil {
argv := append([]string{"ssh"}, args...)
// On Unix, Exec replaces the process and never returns on success.
// On Windows, Exec is not supported and returns an error; fall through to the exec.Command fallback.
_ = syscall.Exec(sshPath, argv, os.Environ())
}
// Fallback for Windows or if LookPath failed
sshCmd := exec.Command("ssh", args...) sshCmd := exec.Command("ssh", args...)
sshCmd.Stdin = os.Stdin sshCmd.Stdin = os.Stdin
sshCmd.Stdout = os.Stdout sshCmd.Stdout = os.Stdout
@@ -213,30 +222,6 @@ func connectToHost(hostName string, remoteCommand []string) {
} }
} }
// getVersionWithUpdateCheck returns a custom version string with update check
func getVersionWithUpdateCheck() string {
versionText := fmt.Sprintf("sshm version %s", AppVersion)
// Check for updates
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
updateInfo, err := version.CheckForUpdates(ctx, AppVersion)
if err != nil {
// Return just version if check fails
return versionText + "\n"
}
if updateInfo != nil && updateInfo.Available {
versionText += fmt.Sprintf("\n🚀 Update available: %s → %s (%s)",
updateInfo.CurrentVer,
updateInfo.LatestVer,
updateInfo.ReleaseURL)
}
return versionText + "\n"
}
// Execute adds all child commands to the root command and sets flags appropriately. // Execute adds all child commands to the root command and sets flags appropriately.
func Execute() { func Execute() {
if err := RootCmd.Execute(); err != nil { if err := RootCmd.Execute(); err != nil {
@@ -258,7 +243,7 @@ func init() {
RootCmd.PersistentFlags().StringVarP(&configFile, "config", "c", "", "SSH config file to use (default: ~/.ssh/config)") RootCmd.PersistentFlags().StringVarP(&configFile, "config", "c", "", "SSH config file to use (default: ~/.ssh/config)")
RootCmd.Flags().BoolVarP(&forceTTY, "tty", "t", false, "Force pseudo-TTY allocation (useful for interactive remote commands)") RootCmd.Flags().BoolVarP(&forceTTY, "tty", "t", false, "Force pseudo-TTY allocation (useful for interactive remote commands)")
RootCmd.PersistentFlags().BoolVarP(&searchMode, "search", "s", false, "Focus on search input at startup") RootCmd.PersistentFlags().BoolVarP(&searchMode, "search", "s", false, "Focus on search input at startup")
RootCmd.PersistentFlags().BoolVar(&noUpdateCheck, "no-update-check", false, "Disable automatic update check")
// Set custom version template with update check RootCmd.SetVersionTemplate("{{.Name}} version {{.Version}}\n")
RootCmd.SetVersionTemplate(getVersionWithUpdateCheck())
} }

View File

@@ -45,7 +45,7 @@ func TestRootCommandFlags(t *testing.T) {
func TestRootCommandSubcommands(t *testing.T) { func TestRootCommandSubcommands(t *testing.T) {
// Test that all expected subcommands are registered // Test that all expected subcommands are registered
// Note: completion and help are automatically added by Cobra and may not always appear in Commands() // Note: completion and help are automatically added by Cobra and may not always appear in Commands()
expectedCommands := []string{"add", "edit", "search"} expectedCommands := []string{"add", "edit", "search", "info"}
commands := RootCmd.Commands() commands := RootCmd.Commands()
commandNames := make(map[string]bool) commandNames := make(map[string]bool)

View File

@@ -18,7 +18,16 @@ type KeyBindings struct {
// AppConfig represents the main application configuration // AppConfig represents the main application configuration
type AppConfig struct { type AppConfig struct {
KeyBindings KeyBindings `json:"key_bindings"` CheckForUpdates *bool `json:"check_for_updates,omitempty"`
KeyBindings KeyBindings `json:"key_bindings"`
}
// IsUpdateCheckEnabled returns true if the update check is enabled (default: true)
func (c *AppConfig) IsUpdateCheckEnabled() bool {
if c == nil || c.CheckForUpdates == nil {
return true
}
return *c.CheckForUpdates
} }
// GetDefaultKeyBindings returns the default key bindings configuration // GetDefaultKeyBindings returns the default key bindings configuration

View File

@@ -104,6 +104,58 @@ func TestAppConfigBasics(t *testing.T) {
if len(defaultConfig.KeyBindings.QuitKeys) != len(expectedQuitKeys) { if len(defaultConfig.KeyBindings.QuitKeys) != len(expectedQuitKeys) {
t.Errorf("Expected %d quit keys, got %d", len(expectedQuitKeys), len(defaultConfig.KeyBindings.QuitKeys)) t.Errorf("Expected %d quit keys, got %d", len(expectedQuitKeys), len(defaultConfig.KeyBindings.QuitKeys))
} }
// CheckForUpdates should be nil by default
if defaultConfig.CheckForUpdates != nil {
t.Error("Default configuration should have CheckForUpdates as nil")
}
// IsUpdateCheckEnabled should return true by default
if !defaultConfig.IsUpdateCheckEnabled() {
t.Error("IsUpdateCheckEnabled should return true when CheckForUpdates is nil")
}
}
func boolPtr(b bool) *bool {
return &b
}
func TestIsUpdateCheckEnabled(t *testing.T) {
tests := []struct {
name string
config *AppConfig
expected bool
}{
{
name: "nil AppConfig returns true",
config: nil,
expected: true,
},
{
name: "CheckForUpdates nil returns true",
config: &AppConfig{},
expected: true,
},
{
name: "CheckForUpdates true returns true",
config: &AppConfig{CheckForUpdates: boolPtr(true)},
expected: true,
},
{
name: "CheckForUpdates false returns false",
config: &AppConfig{CheckForUpdates: boolPtr(false)},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.config.IsUpdateCheckEnabled()
if result != tt.expected {
t.Errorf("IsUpdateCheckEnabled() = %v, expected %v", result, tt.expected)
}
})
}
} }
func TestMergeWithDefaults(t *testing.T) { func TestMergeWithDefaults(t *testing.T) {
@@ -141,6 +193,7 @@ func TestSaveAndLoadAppConfigIntegration(t *testing.T) {
configPath := filepath.Join(tempDir, "config.json") configPath := filepath.Join(tempDir, "config.json")
customConfig := AppConfig{ customConfig := AppConfig{
CheckForUpdates: boolPtr(false),
KeyBindings: KeyBindings{ KeyBindings: KeyBindings{
QuitKeys: []string{"q"}, QuitKeys: []string{"q"},
DisableEscQuit: true, DisableEscQuit: true,
@@ -178,4 +231,15 @@ func TestSaveAndLoadAppConfigIntegration(t *testing.T) {
if len(loadedConfig.KeyBindings.QuitKeys) != 1 || loadedConfig.KeyBindings.QuitKeys[0] != "q" { if len(loadedConfig.KeyBindings.QuitKeys) != 1 || loadedConfig.KeyBindings.QuitKeys[0] != "q" {
t.Errorf("Expected quit keys to be ['q'], got %v", loadedConfig.KeyBindings.QuitKeys) t.Errorf("Expected quit keys to be ['q'], got %v", loadedConfig.KeyBindings.QuitKeys)
} }
// Verify CheckForUpdates is correctly persisted and reloaded
if loadedConfig.CheckForUpdates == nil {
t.Fatal("CheckForUpdates should not be nil after round-trip")
}
if *loadedConfig.CheckForUpdates != false {
t.Errorf("CheckForUpdates should be false after round-trip, got %v", *loadedConfig.CheckForUpdates)
}
if loadedConfig.IsUpdateCheckEnabled() {
t.Error("IsUpdateCheckEnabled should return false when CheckForUpdates is false")
}
} }

View File

@@ -11,6 +11,18 @@ import (
"sync" "sync"
) )
func getHomeDir() (string, error) {
home := os.Getenv("HOME")
if home != "" {
return home, nil
}
home = os.Getenv("USERPROFILE")
if home != "" {
return home, nil
}
return os.UserHomeDir()
}
// SSHHost represents an SSH host configuration // SSHHost represents an SSH host configuration
type SSHHost struct { type SSHHost struct {
Name string Name string
@@ -33,7 +45,7 @@ type SSHHost struct {
// GetDefaultSSHConfigPath returns the default SSH config path for the current platform // GetDefaultSSHConfigPath returns the default SSH config path for the current platform
func GetDefaultSSHConfigPath() (string, error) { func GetDefaultSSHConfigPath() (string, error) {
homeDir, err := os.UserHomeDir() homeDir, err := getHomeDir()
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -49,7 +61,7 @@ func GetDefaultSSHConfigPath() (string, error) {
// GetSSHMConfigDir returns the SSHM config directory // GetSSHMConfigDir returns the SSHM config directory
func GetSSHMConfigDir() (string, error) { func GetSSHMConfigDir() (string, error) {
homeDir, err := os.UserHomeDir() homeDir, err := getHomeDir()
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -88,7 +100,7 @@ func GetSSHMBackupDir() (string, error) {
// GetSSHDirectory returns the .ssh directory path // GetSSHDirectory returns the .ssh directory path
func GetSSHDirectory() (string, error) { func GetSSHDirectory() (string, error) {
homeDir, err := os.UserHomeDir() homeDir, err := getHomeDir()
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -386,7 +398,7 @@ func parseSSHConfigFileWithProcessedFiles(configPath string, processedFiles map[
func processIncludeDirective(pattern string, baseConfigPath string, processedFiles map[string]bool) ([]SSHHost, error) { func processIncludeDirective(pattern string, baseConfigPath string, processedFiles map[string]bool) ([]SSHHost, error) {
// Expand tilde to home directory // Expand tilde to home directory
if strings.HasPrefix(pattern, "~") { if strings.HasPrefix(pattern, "~") {
homeDir, err := os.UserHomeDir() homeDir, err := getHomeDir()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get home directory: %w", err) return nil, fmt.Errorf("failed to get home directory: %w", err)
} }
@@ -918,7 +930,7 @@ func quickHostSearchInFile(hostName string, configPath string, processedFiles ma
func quickSearchInclude(hostName, pattern, baseConfigPath string, processedFiles map[string]bool) (bool, error) { func quickSearchInclude(hostName, pattern, baseConfigPath string, processedFiles map[string]bool) (bool, error) {
// Expand tilde to home directory // Expand tilde to home directory
if strings.HasPrefix(pattern, "~") { if strings.HasPrefix(pattern, "~") {
homeDir, err := os.UserHomeDir() homeDir, err := getHomeDir()
if err != nil { if err != nil {
return false, fmt.Errorf("failed to get home directory: %w", err) return false, fmt.Errorf("failed to get home directory: %w", err)
} }

View File

@@ -456,15 +456,21 @@ func TestBackupConfigToSSHMDirectory(t *testing.T) {
// Create temporary directory for test files // Create temporary directory for test files
tempDir := t.TempDir() tempDir := t.TempDir()
// Override the home directory for this test
originalHome := os.Getenv("HOME") originalHome := os.Getenv("HOME")
if originalHome == "" { if originalHome == "" {
originalHome = os.Getenv("USERPROFILE") // Windows originalHome = os.Getenv("USERPROFILE")
} }
originalXDG := os.Getenv("XDG_CONFIG_HOME")
originalAppData := os.Getenv("APPDATA")
// Set test home directory
os.Setenv("HOME", tempDir) os.Setenv("HOME", tempDir)
defer os.Setenv("HOME", originalHome) os.Setenv("XDG_CONFIG_HOME", tempDir)
os.Setenv("APPDATA", tempDir)
defer func() {
os.Setenv("HOME", originalHome)
os.Setenv("XDG_CONFIG_HOME", originalXDG)
os.Setenv("APPDATA", originalAppData)
}()
// Create a test SSH config file // Create a test SSH config file
sshDir := filepath.Join(tempDir, ".ssh") sshDir := filepath.Join(tempDir, ".ssh")

View File

@@ -2,12 +2,15 @@ package connectivity
import ( import (
"context" "context"
"fmt"
"net" "net"
"github.com/Gu1llaum-3/sshm/internal/config" "os/exec"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/Gu1llaum-3/sshm/internal/config"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@@ -45,16 +48,18 @@ type HostPingResult struct {
// PingManager manages SSH connectivity checks for multiple hosts // PingManager manages SSH connectivity checks for multiple hosts
type PingManager struct { type PingManager struct {
results map[string]*HostPingResult results map[string]*HostPingResult
mutex sync.RWMutex mutex sync.RWMutex
timeout time.Duration timeout time.Duration
configFile string
} }
// NewPingManager creates a new ping manager with the specified timeout // NewPingManager creates a new ping manager with the specified timeout
func NewPingManager(timeout time.Duration) *PingManager { func NewPingManager(timeout time.Duration, configFile string) *PingManager {
return &PingManager{ return &PingManager{
results: make(map[string]*HostPingResult), results: make(map[string]*HostPingResult),
timeout: timeout, timeout: timeout,
configFile: configFile,
} }
} }
@@ -98,6 +103,14 @@ func (pm *PingManager) PingHost(ctx context.Context, host config.SSHHost) *HostP
// Mark as connecting // Mark as connecting
pm.updateStatus(host.Name, StatusConnecting, nil, 0) pm.updateStatus(host.Name, StatusConnecting, nil, 0)
// If the host uses a ProxyJump or ProxyCommand, we need to use the external SSH command
// because implementing jump host support with pure Go ssh library requires
// handling authentication for the jump host, which is complex and requires
// access to the user's SSH agent or keys.
if host.ProxyJump != "" || host.ProxyCommand != "" {
return pm.pingWithExternalCommand(ctx, host, start)
}
// Determine the actual hostname and port // Determine the actual hostname and port
hostname := host.Hostname hostname := host.Hostname
if hostname == "" { if hostname == "" {
@@ -159,6 +172,53 @@ func (pm *PingManager) PingHost(ctx context.Context, host config.SSHHost) *HostP
} }
} }
// pingWithExternalCommand pings a host using the external SSH command
func (pm *PingManager) pingWithExternalCommand(ctx context.Context, host config.SSHHost, start time.Time) *HostPingResult {
// Construct the SSH command
// ssh -q -o BatchMode=yes -o StrictHostKeyChecking=no -o ConnectTimeout=5 host exit
args := []string{"-q", "-o", "BatchMode=yes", "-o", "StrictHostKeyChecking=no"}
// Set timeout matching the manager's timeout
// Convert duration to seconds (rounding up to ensure we don't timeout too early in the command)
timeoutSec := int(pm.timeout.Seconds())
if timeoutSec < 1 {
timeoutSec = 1
}
args = append(args, "-o", fmt.Sprintf("ConnectTimeout=%d", timeoutSec))
// If we have a specific config file, use it
if pm.configFile != "" {
args = append(args, "-F", pm.configFile)
}
// Add the host name and the command to run (exit)
args = append(args, host.Name, "exit")
// Create command with context for timeout cancellation
// Note: We used pm.timeout for the ssh command option, but we also respect the context deadline
cmd := exec.CommandContext(ctx, "ssh", args...)
// Run the command
err := cmd.Run()
duration := time.Since(start)
var status PingStatus
if err != nil {
// SSH returns non-zero exit code on connection failure
status = StatusOffline
} else {
status = StatusOnline
}
pm.updateStatus(host.Name, status, err, duration)
return &HostPingResult{
HostName: host.Name,
Status: status,
Error: err,
Duration: duration,
}
}
// PingAllHosts pings all hosts concurrently and returns a channel of results // PingAllHosts pings all hosts concurrently and returns a channel of results
func (pm *PingManager) PingAllHosts(ctx context.Context, hosts []config.SSHHost) <-chan *HostPingResult { func (pm *PingManager) PingAllHosts(ctx context.Context, hosts []config.SSHHost) <-chan *HostPingResult {
resultChan := make(chan *HostPingResult, len(hosts)) resultChan := make(chan *HostPingResult, len(hosts))

View File

@@ -9,7 +9,7 @@ import (
) )
func TestNewPingManager(t *testing.T) { func TestNewPingManager(t *testing.T) {
pm := NewPingManager(5 * time.Second) pm := NewPingManager(5*time.Second, "")
if pm == nil { if pm == nil {
t.Error("NewPingManager() returned nil") t.Error("NewPingManager() returned nil")
} }
@@ -19,16 +19,16 @@ func TestNewPingManager(t *testing.T) {
} }
func TestPingManager_PingHost(t *testing.T) { func TestPingManager_PingHost(t *testing.T) {
pm := NewPingManager(1 * time.Second) pm := NewPingManager(1*time.Second, "")
ctx := context.Background() ctx := context.Background()
// Test ping method exists and doesn't panic // Test ping method exists and doesn't panic
host := config.SSHHost{Name: "test", Hostname: "127.0.0.1", Port: "22"} host := config.SSHHost{Name: "test", Hostname: "127.0.0.1", Port: "22"}
result := pm.PingHost(ctx, host) result := pm.PingHost(ctx, host)
if result == nil { if result == nil {
t.Error("Expected ping result to be returned") t.Error("Expected ping result to be returned")
} }
// Test with invalid host // Test with invalid host
invalidHost := config.SSHHost{Name: "invalid", Hostname: "invalid.host.12345", Port: "22"} invalidHost := config.SSHHost{Name: "invalid", Hostname: "invalid.host.12345", Port: "22"}
result = pm.PingHost(ctx, invalidHost) result = pm.PingHost(ctx, invalidHost)
@@ -38,14 +38,14 @@ func TestPingManager_PingHost(t *testing.T) {
} }
func TestPingManager_GetStatus(t *testing.T) { func TestPingManager_GetStatus(t *testing.T) {
pm := NewPingManager(1 * time.Second) pm := NewPingManager(1*time.Second, "")
// Test unknown host // Test unknown host
status := pm.GetStatus("unknown.host") status := pm.GetStatus("unknown.host")
if status != StatusUnknown { if status != StatusUnknown {
t.Errorf("Expected StatusUnknown for unknown host, got %v", status) t.Errorf("Expected StatusUnknown for unknown host, got %v", status)
} }
// Test after ping // Test after ping
ctx := context.Background() ctx := context.Background()
host := config.SSHHost{Name: "test", Hostname: "127.0.0.1", Port: "22"} host := config.SSHHost{Name: "test", Hostname: "127.0.0.1", Port: "22"}
@@ -57,21 +57,21 @@ func TestPingManager_GetStatus(t *testing.T) {
} }
func TestPingManager_PingMultipleHosts(t *testing.T) { func TestPingManager_PingMultipleHosts(t *testing.T) {
pm := NewPingManager(1 * time.Second) pm := NewPingManager(1*time.Second, "")
hosts := []config.SSHHost{ hosts := []config.SSHHost{
{Name: "localhost", Hostname: "127.0.0.1", Port: "22"}, {Name: "localhost", Hostname: "127.0.0.1", Port: "22"},
{Name: "invalid", Hostname: "invalid.host.12345", Port: "22"}, {Name: "invalid", Hostname: "invalid.host.12345", Port: "22"},
} }
ctx := context.Background() ctx := context.Background()
// Ping each host individually // Ping each host individually
for _, host := range hosts { for _, host := range hosts {
result := pm.PingHost(ctx, host) result := pm.PingHost(ctx, host)
if result == nil { if result == nil {
t.Errorf("Expected ping result for host %s", host.Name) t.Errorf("Expected ping result for host %s", host.Name)
} }
// Check that status was set // Check that status was set
status := pm.GetStatus(host.Name) status := pm.GetStatus(host.Name)
if status == StatusUnknown { if status == StatusUnknown {
@@ -81,19 +81,19 @@ func TestPingManager_PingMultipleHosts(t *testing.T) {
} }
func TestPingManager_GetResult(t *testing.T) { func TestPingManager_GetResult(t *testing.T) {
pm := NewPingManager(1 * time.Second) pm := NewPingManager(1*time.Second, "")
ctx := context.Background() ctx := context.Background()
// Test getting result for unknown host // Test getting result for unknown host
result, exists := pm.GetResult("unknown") result, exists := pm.GetResult("unknown")
if exists || result != nil { if exists || result != nil {
t.Error("Expected no result for unknown host") t.Error("Expected no result for unknown host")
} }
// Test after ping // Test after ping
host := config.SSHHost{Name: "test", Hostname: "127.0.0.1", Port: "22"} host := config.SSHHost{Name: "test", Hostname: "127.0.0.1", Port: "22"}
pm.PingHost(ctx, host) pm.PingHost(ctx, host)
result, exists = pm.GetResult("test") result, exists = pm.GetResult("test")
if !exists || result == nil { if !exists || result == nil {
t.Error("Expected result to exist after ping") t.Error("Expected result to exist after ping")
@@ -114,7 +114,7 @@ func TestPingStatus_String(t *testing.T) {
{StatusOffline, "offline"}, {StatusOffline, "offline"},
{PingStatus(999), "unknown"}, // Invalid status {PingStatus(999), "unknown"}, // Invalid status
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.expected, func(t *testing.T) { t.Run(tt.expected, func(t *testing.T) {
if got := tt.status.String(); got != tt.expected { if got := tt.status.String(); got != tt.expected {
@@ -126,19 +126,19 @@ func TestPingStatus_String(t *testing.T) {
func TestPingHost_Basic(t *testing.T) { func TestPingHost_Basic(t *testing.T) {
// Test that the ping functionality exists // Test that the ping functionality exists
pm := NewPingManager(1 * time.Second) pm := NewPingManager(1*time.Second, "")
ctx := context.Background() ctx := context.Background()
host := config.SSHHost{Name: "test", Hostname: "127.0.0.1", Port: "22"} host := config.SSHHost{Name: "test", Hostname: "127.0.0.1", Port: "22"}
// Just ensure the function doesn't panic // Just ensure the function doesn't panic
result := pm.PingHost(ctx, host) result := pm.PingHost(ctx, host)
if result == nil { if result == nil {
t.Error("Expected ping result to be returned") t.Error("Expected ping result to be returned")
} }
// Test that status is set // Test that status is set
status := pm.GetStatus("test") status := pm.GetStatus("test")
if status == StatusUnknown { if status == StatusUnknown {
t.Error("Expected status to be set after ping attempt") t.Error("Expected status to be set after ping attempt")
} }
} }

View File

@@ -16,7 +16,7 @@ import (
) )
// NewModel creates a new TUI model with the given SSH hosts // NewModel creates a new TUI model with the given SSH hosts
func NewModel(hosts []config.SSHHost, configFile string, searchMode bool, currentVersion string) Model { func NewModel(hosts []config.SSHHost, configFile string, searchMode bool, currentVersion string, noUpdateCheck bool) Model {
// Load application configuration // Load application configuration
appConfig, err := config.LoadAppConfig() appConfig, err := config.LoadAppConfig()
if err != nil { if err != nil {
@@ -26,6 +26,12 @@ func NewModel(hosts []config.SSHHost, configFile string, searchMode bool, curren
appConfig = &defaultConfig appConfig = &defaultConfig
} }
// CLI flag overrides config file setting
if noUpdateCheck {
f := false
appConfig.CheckForUpdates = &f
}
// Initialize the history manager // Initialize the history manager
historyManager, err := history.NewHistoryManager() historyManager, err := history.NewHistoryManager()
if err != nil { if err != nil {
@@ -38,7 +44,7 @@ func NewModel(hosts []config.SSHHost, configFile string, searchMode bool, curren
styles := NewStyles(80) // Default width styles := NewStyles(80) // Default width
// Initialize ping manager with 5 second timeout // Initialize ping manager with 5 second timeout
pingManager := connectivity.NewPingManager(5 * time.Second) pingManager := connectivity.NewPingManager(5*time.Second, configFile)
// Create the model with default sorting by name // Create the model with default sorting by name
m := Model{ m := Model{
@@ -151,8 +157,8 @@ func NewModel(hosts []config.SSHHost, configFile string, searchMode bool, curren
} }
// RunInteractiveMode starts the interactive TUI interface // RunInteractiveMode starts the interactive TUI interface
func RunInteractiveMode(hosts []config.SSHHost, configFile string, searchMode bool, currentVersion string) error { func RunInteractiveMode(hosts []config.SSHHost, configFile string, searchMode bool, currentVersion string, noUpdateCheck bool) error {
m := NewModel(hosts, configFile, searchMode, currentVersion) m := NewModel(hosts, configFile, searchMode, currentVersion, noUpdateCheck)
// Start the application in alt screen mode for clean output // Start the application in alt screen mode for clean output
p := tea.NewProgram(m, tea.WithAltScreen()) p := tea.NewProgram(m, tea.WithAltScreen())

View File

@@ -74,8 +74,8 @@ func (m Model) Init() tea.Cmd {
// Basic initialization commands // Basic initialization commands
cmds = append(cmds, textinput.Blink) cmds = append(cmds, textinput.Blink)
// Check for version updates if we have a current version // Check for version updates if we have a current version and updates are enabled
if m.currentVersion != "" { if m.currentVersion != "" && m.appConfig.IsUpdateCheckEnabled() {
cmds = append(cmds, checkVersionCmd(m.currentVersion)) cmds = append(cmds, checkVersionCmd(m.currentVersion))
} }

View File

@@ -66,6 +66,25 @@ func ValidateIdentityFile(path string) bool {
if path == "" { if path == "" {
return true // Optional field return true // Optional field
} }
// SSH tokens (e.g. %d, %h, %r, %u) are resolved by SSH at connection time
sshTokenRegex := regexp.MustCompile(`%[hprunCdiklLT]`)
if sshTokenRegex.MatchString(path) {
return true
}
// Expand environment variables ($VAR and ${VAR}); track undefined ones
hasUndefined := false
path = os.Expand(path, func(key string) string {
val, ok := os.LookupEnv(key)
if !ok {
hasUndefined = true
return "$" + key
}
return val
})
// If any variable was undefined, accept the path (SSH will report the error)
if hasUndefined {
return true
}
// Expand ~ to home directory // Expand ~ to home directory
if strings.HasPrefix(path, "~/") { if strings.HasPrefix(path, "~/") {
homeDir, err := os.UserHomeDir() homeDir, err := os.UserHomeDir()

View File

@@ -133,6 +133,9 @@ func TestValidateIdentityFile(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
// Set up an env var pointing to the valid file's directory for env var tests
t.Setenv("TEST_SSHM_DIR", tmpDir)
tests := []struct { tests := []struct {
name string name string
path string path string
@@ -143,6 +146,13 @@ func TestValidateIdentityFile(t *testing.T) {
{"non-existent file", "/path/to/nonexistent", false}, {"non-existent file", "/path/to/nonexistent", false},
// Skip tilde path test in CI environments where ~/.ssh/id_rsa may not exist // Skip tilde path test in CI environments where ~/.ssh/id_rsa may not exist
// {"tilde path", "~/.ssh/id_rsa", true}, // Will pass if file exists // {"tilde path", "~/.ssh/id_rsa", true}, // Will pass if file exists
// Environment variable expansion (issue #33)
{"env var $VAR/key defined", "$TEST_SSHM_DIR/test_key", true},
{"env var ${VAR}/key defined", "${TEST_SSHM_DIR}/test_key", true},
{"env var undefined", "$UNDEFINED_SSHM_VAR_XYZ/key", true},
// SSH tokens
{"SSH token %d", "%d/.ssh/id_rsa", true},
{"SSH token %h", "%h-key", true},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -170,6 +180,7 @@ func TestValidateHost(t *testing.T) {
if err := os.WriteFile(validIdentity, []byte("test"), 0600); err != nil { if err := os.WriteFile(validIdentity, []byte("test"), 0600); err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Setenv("TEST_SSHM_HOST_DIR", tmpDir)
tests := []struct { tests := []struct {
name string name string
@@ -187,6 +198,9 @@ func TestValidateHost(t *testing.T) {
{"invalid hostname", "myserver", "invalid..hostname", "22", "", true}, {"invalid hostname", "myserver", "invalid..hostname", "22", "", true},
{"invalid port", "myserver", "example.com", "99999", "", true}, {"invalid port", "myserver", "example.com", "99999", "", true},
{"invalid identity", "myserver", "example.com", "22", "/nonexistent", true}, {"invalid identity", "myserver", "example.com", "22", "/nonexistent", true},
// Environment variables and SSH tokens in identity (issue #33)
{"identity with env var", "myserver", "example.com", "22", "$TEST_SSHM_HOST_DIR/test_key", false},
{"identity with SSH token", "myserver", "example.com", "22", "%d/.ssh/id_rsa", false},
} }
for _, tt := range tests { for _, tt := range tests {