feat: add port forwarding history persistence

This commit is contained in:
Gu1llaum-3 2025-09-12 11:29:17 +02:00
parent 3da3a33530
commit ab5f430ee1
4 changed files with 334 additions and 37 deletions

View File

@ -15,11 +15,21 @@ type ConnectionHistory struct {
Connections map[string]ConnectionInfo `json:"connections"` Connections map[string]ConnectionInfo `json:"connections"`
} }
// PortForwardConfig stores port forwarding configuration
type PortForwardConfig struct {
Type string `json:"type"` // "local", "remote", "dynamic"
LocalPort string `json:"local_port"`
RemoteHost string `json:"remote_host"`
RemotePort string `json:"remote_port"`
BindAddress string `json:"bind_address"`
}
// ConnectionInfo stores information about a specific connection // ConnectionInfo stores information about a specific connection
type ConnectionInfo struct { type ConnectionInfo struct {
HostName string `json:"host_name"` HostName string `json:"host_name"`
LastConnect time.Time `json:"last_connect"` LastConnect time.Time `json:"last_connect"`
ConnectCount int `json:"connect_count"` ConnectCount int `json:"connect_count"`
PortForwarding *PortForwardConfig `json:"port_forwarding,omitempty"`
} }
// HistoryManager manages the connection history // HistoryManager manages the connection history
@ -257,3 +267,42 @@ func (hm *HistoryManager) GetAllConnectionsInfo() []ConnectionInfo {
return connections return connections
} }
// RecordPortForwarding saves port forwarding configuration for a host
func (hm *HistoryManager) RecordPortForwarding(hostName, forwardType, localPort, remoteHost, remotePort, bindAddress string) error {
now := time.Now()
portForwardConfig := &PortForwardConfig{
Type: forwardType,
LocalPort: localPort,
RemoteHost: remoteHost,
RemotePort: remotePort,
BindAddress: bindAddress,
}
if conn, exists := hm.history.Connections[hostName]; exists {
// Update existing connection
conn.LastConnect = now
conn.ConnectCount++
conn.PortForwarding = portForwardConfig
hm.history.Connections[hostName] = conn
} else {
// Create new connection record
hm.history.Connections[hostName] = ConnectionInfo{
HostName: hostName,
LastConnect: now,
ConnectCount: 1,
PortForwarding: portForwardConfig,
}
}
return hm.saveHistory()
}
// GetPortForwardingConfig retrieves the last used port forwarding configuration for a host
func (hm *HistoryManager) GetPortForwardingConfig(hostName string) *PortForwardConfig {
if conn, exists := hm.history.Connections[hostName]; exists {
return conn.PortForwarding
}
return nil
}

View File

@ -0,0 +1,183 @@
package history
import (
"os"
"path/filepath"
"testing"
)
func TestPortForwardingHistory(t *testing.T) {
// Create temporary directory for testing
tempDir, err := os.MkdirTemp("", "sshm_test_*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tempDir)
// Create history manager with temp directory
historyPath := filepath.Join(tempDir, "test_history.json")
hm := &HistoryManager{
historyPath: historyPath,
history: &ConnectionHistory{Connections: make(map[string]ConnectionInfo)},
}
hostName := "test-server"
// Test recording port forwarding configuration
err = hm.RecordPortForwarding(hostName, "local", "8080", "localhost", "80", "127.0.0.1")
if err != nil {
t.Fatalf("Failed to record port forwarding: %v", err)
}
// Test retrieving port forwarding configuration
config := hm.GetPortForwardingConfig(hostName)
if config == nil {
t.Fatalf("Expected port forwarding config to exist")
}
// Verify the saved configuration
if config.Type != "local" {
t.Errorf("Expected Type 'local', got %s", config.Type)
}
if config.LocalPort != "8080" {
t.Errorf("Expected LocalPort '8080', got %s", config.LocalPort)
}
if config.RemoteHost != "localhost" {
t.Errorf("Expected RemoteHost 'localhost', got %s", config.RemoteHost)
}
if config.RemotePort != "80" {
t.Errorf("Expected RemotePort '80', got %s", config.RemotePort)
}
if config.BindAddress != "127.0.0.1" {
t.Errorf("Expected BindAddress '127.0.0.1', got %s", config.BindAddress)
}
// Test updating configuration with different values
err = hm.RecordPortForwarding(hostName, "remote", "3000", "app-server", "8000", "")
if err != nil {
t.Fatalf("Failed to record updated port forwarding: %v", err)
}
// Verify the updated configuration
config = hm.GetPortForwardingConfig(hostName)
if config == nil {
t.Fatalf("Expected port forwarding config to exist after update")
}
if config.Type != "remote" {
t.Errorf("Expected updated Type 'remote', got %s", config.Type)
}
if config.LocalPort != "3000" {
t.Errorf("Expected updated LocalPort '3000', got %s", config.LocalPort)
}
if config.RemoteHost != "app-server" {
t.Errorf("Expected updated RemoteHost 'app-server', got %s", config.RemoteHost)
}
if config.RemotePort != "8000" {
t.Errorf("Expected updated RemotePort '8000', got %s", config.RemotePort)
}
if config.BindAddress != "" {
t.Errorf("Expected updated BindAddress to be empty, got %s", config.BindAddress)
}
// Test dynamic forwarding
err = hm.RecordPortForwarding(hostName, "dynamic", "1080", "", "", "0.0.0.0")
if err != nil {
t.Fatalf("Failed to record dynamic port forwarding: %v", err)
}
config = hm.GetPortForwardingConfig(hostName)
if config == nil {
t.Fatalf("Expected port forwarding config to exist for dynamic forwarding")
}
if config.Type != "dynamic" {
t.Errorf("Expected Type 'dynamic', got %s", config.Type)
}
if config.LocalPort != "1080" {
t.Errorf("Expected LocalPort '1080', got %s", config.LocalPort)
}
if config.RemoteHost != "" {
t.Errorf("Expected RemoteHost to be empty for dynamic forwarding, got %s", config.RemoteHost)
}
if config.RemotePort != "" {
t.Errorf("Expected RemotePort to be empty for dynamic forwarding, got %s", config.RemotePort)
}
if config.BindAddress != "0.0.0.0" {
t.Errorf("Expected BindAddress '0.0.0.0', got %s", config.BindAddress)
}
}
func TestPortForwardingHistoryPersistence(t *testing.T) {
// Create temporary directory for testing
tempDir, err := os.MkdirTemp("", "sshm_test_*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tempDir)
historyPath := filepath.Join(tempDir, "test_history.json")
// Create first history manager and record data
hm1 := &HistoryManager{
historyPath: historyPath,
history: &ConnectionHistory{Connections: make(map[string]ConnectionInfo)},
}
hostName := "persistent-server"
err = hm1.RecordPortForwarding(hostName, "local", "9090", "db-server", "5432", "")
if err != nil {
t.Fatalf("Failed to record port forwarding: %v", err)
}
// Create second history manager and load data
hm2 := &HistoryManager{
historyPath: historyPath,
history: &ConnectionHistory{Connections: make(map[string]ConnectionInfo)},
}
err = hm2.loadHistory()
if err != nil {
t.Fatalf("Failed to load history: %v", err)
}
// Verify the loaded configuration
config := hm2.GetPortForwardingConfig(hostName)
if config == nil {
t.Fatalf("Expected port forwarding config to be loaded from file")
}
if config.Type != "local" {
t.Errorf("Expected loaded Type 'local', got %s", config.Type)
}
if config.LocalPort != "9090" {
t.Errorf("Expected loaded LocalPort '9090', got %s", config.LocalPort)
}
if config.RemoteHost != "db-server" {
t.Errorf("Expected loaded RemoteHost 'db-server', got %s", config.RemoteHost)
}
if config.RemotePort != "5432" {
t.Errorf("Expected loaded RemotePort '5432', got %s", config.RemotePort)
}
}
func TestGetPortForwardingConfigNonExistent(t *testing.T) {
// Create temporary directory for testing
tempDir, err := os.MkdirTemp("", "sshm_test_*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tempDir)
historyPath := filepath.Join(tempDir, "test_history.json")
hm := &HistoryManager{
historyPath: historyPath,
history: &ConnectionHistory{Connections: make(map[string]ConnectionInfo)},
}
// Test getting configuration for non-existent host
config := hm.GetPortForwardingConfig("non-existent-host")
if config != nil {
t.Errorf("Expected nil config for non-existent host, got %+v", config)
}
}

View File

@ -5,6 +5,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/Gu1llaum-3/sshm/internal/history"
"github.com/charmbracelet/bubbles/textinput" "github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
@ -20,15 +21,16 @@ const (
) )
type portForwardModel struct { type portForwardModel struct {
inputs []textinput.Model inputs []textinput.Model
focused int focused int
forwardType PortForwardType forwardType PortForwardType
hostName string hostName string
err string err string
styles Styles styles Styles
width int width int
height int height int
configFile string configFile string
historyManager *history.HistoryManager
} }
// portForwardSubmitMsg is sent when the port forward form is submitted // portForwardSubmitMsg is sent when the port forward form is submitted
@ -41,7 +43,7 @@ type portForwardSubmitMsg struct {
type portForwardCancelMsg struct{} type portForwardCancelMsg struct{}
// NewPortForwardForm creates a new port forward form model // NewPortForwardForm creates a new port forward form model
func NewPortForwardForm(hostName string, styles Styles, width, height int, configFile string) *portForwardModel { func NewPortForwardForm(hostName string, styles Styles, width, height int, configFile string, historyManager *history.HistoryManager) *portForwardModel {
inputs := make([]textinput.Model, 5) inputs := make([]textinput.Model, 5)
// Forward type input (display only, controlled by arrow keys) // Forward type input (display only, controlled by arrow keys)
@ -49,7 +51,6 @@ func NewPortForwardForm(hostName string, styles Styles, width, height int, confi
inputs[pfTypeInput].Placeholder = "Use ←/→ to change forward type" inputs[pfTypeInput].Placeholder = "Use ←/→ to change forward type"
inputs[pfTypeInput].Focus() inputs[pfTypeInput].Focus()
inputs[pfTypeInput].Width = 40 inputs[pfTypeInput].Width = 40
inputs[pfTypeInput].SetValue("Local (-L)")
// Local port input // Local port input
inputs[pfLocalPortInput] = textinput.New() inputs[pfLocalPortInput] = textinput.New()
@ -77,16 +78,20 @@ func NewPortForwardForm(hostName string, styles Styles, width, height int, confi
inputs[pfBindAddressInput].Width = 30 inputs[pfBindAddressInput].Width = 30
pf := &portForwardModel{ pf := &portForwardModel{
inputs: inputs, inputs: inputs,
focused: 0, focused: 0,
forwardType: LocalForward, forwardType: LocalForward,
hostName: hostName, hostName: hostName,
styles: styles, styles: styles,
width: width, width: width,
height: height, height: height,
configFile: configFile, configFile: configFile,
historyManager: historyManager,
} }
// Load previous port forwarding configuration if available
pf.loadPreviousConfig()
// Initialize input visibility // Initialize input visibility
pf.updateInputVisibility() pf.updateInputVisibility()
@ -370,6 +375,11 @@ func (m *portForwardModel) submitForm() tea.Cmd {
return portForwardSubmitMsg{err: fmt.Errorf("invalid port number"), sshArgs: nil} return portForwardSubmitMsg{err: fmt.Errorf("invalid port number"), sshArgs: nil}
} }
// Get form values for saving to history
remoteHost := strings.TrimSpace(m.inputs[pfRemoteHostInput].Value())
remotePort := strings.TrimSpace(m.inputs[pfRemotePortInput].Value())
bindAddress := strings.TrimSpace(m.inputs[pfBindAddressInput].Value())
// Build SSH command with port forwarding // Build SSH command with port forwarding
var sshArgs []string var sshArgs []string
@ -379,13 +389,10 @@ func (m *portForwardModel) submitForm() tea.Cmd {
} }
// Add forwarding arguments // Add forwarding arguments
bindAddress := strings.TrimSpace(m.inputs[pfBindAddressInput].Value()) var forwardTypeStr string
switch m.forwardType { switch m.forwardType {
case LocalForward: case LocalForward:
remoteHost := strings.TrimSpace(m.inputs[pfRemoteHostInput].Value()) forwardTypeStr = "local"
remotePort := strings.TrimSpace(m.inputs[pfRemotePortInput].Value())
if remoteHost == "" { if remoteHost == "" {
remoteHost = "localhost" remoteHost = "localhost"
} }
@ -408,31 +415,30 @@ func (m *portForwardModel) submitForm() tea.Cmd {
sshArgs = append(sshArgs, "-L", forwardArg) sshArgs = append(sshArgs, "-L", forwardArg)
case RemoteForward: case RemoteForward:
localHost := strings.TrimSpace(m.inputs[pfRemoteHostInput].Value()) forwardTypeStr = "remote"
localPortStr := strings.TrimSpace(m.inputs[pfRemotePortInput].Value()) if remoteHost == "" {
remoteHost = "localhost"
if localHost == "" {
localHost = "localhost"
} }
if localPortStr == "" { if remotePort == "" {
return portForwardSubmitMsg{err: fmt.Errorf("local port is required for remote forwarding"), sshArgs: nil} return portForwardSubmitMsg{err: fmt.Errorf("local port is required for remote forwarding"), sshArgs: nil}
} }
// Validate local port // Validate local port
if _, err := strconv.Atoi(localPortStr); err != nil { if _, err := strconv.Atoi(remotePort); err != nil {
return portForwardSubmitMsg{err: fmt.Errorf("invalid local port number"), sshArgs: nil} return portForwardSubmitMsg{err: fmt.Errorf("invalid local port number"), sshArgs: nil}
} }
// Build -R argument (note: localPort is actually the remote port in this context) // Build -R argument (note: localPort is actually the remote port in this context)
var forwardArg string var forwardArg string
if bindAddress != "" { if bindAddress != "" {
forwardArg = fmt.Sprintf("%s:%s:%s:%s", bindAddress, localPort, localHost, localPortStr) forwardArg = fmt.Sprintf("%s:%s:%s:%s", bindAddress, localPort, remoteHost, remotePort)
} else { } else {
forwardArg = fmt.Sprintf("%s:%s:%s", localPort, localHost, localPortStr) forwardArg = fmt.Sprintf("%s:%s:%s", localPort, remoteHost, remotePort)
} }
sshArgs = append(sshArgs, "-R", forwardArg) sshArgs = append(sshArgs, "-R", forwardArg)
case DynamicForward: case DynamicForward:
forwardTypeStr = "dynamic"
// Build -D argument // Build -D argument
var forwardArg string var forwardArg string
if bindAddress != "" { if bindAddress != "" {
@ -443,6 +449,21 @@ func (m *portForwardModel) submitForm() tea.Cmd {
sshArgs = append(sshArgs, "-D", forwardArg) sshArgs = append(sshArgs, "-D", forwardArg)
} }
// Save port forwarding configuration to history
if m.historyManager != nil {
if err := m.historyManager.RecordPortForwarding(
m.hostName,
forwardTypeStr,
localPort,
remoteHost,
remotePort,
bindAddress,
); err != nil {
// Log the error but don't fail the connection
// In a production environment, you might want to handle this differently
}
}
// Add hostname // Add hostname
sshArgs = append(sshArgs, m.hostName) sshArgs = append(sshArgs, m.hostName)
@ -488,3 +509,47 @@ func (m *portForwardModel) getPrevValidField(currentField int) int {
} }
return -1 return -1
} }
// loadPreviousConfig loads the previous port forwarding configuration for this host
func (m *portForwardModel) loadPreviousConfig() {
if m.historyManager == nil {
m.inputs[pfTypeInput].SetValue("Local (-L)")
return
}
config := m.historyManager.GetPortForwardingConfig(m.hostName)
if config == nil {
m.inputs[pfTypeInput].SetValue("Local (-L)")
return
}
// Set forward type based on saved configuration
switch config.Type {
case "local":
m.forwardType = LocalForward
case "remote":
m.forwardType = RemoteForward
case "dynamic":
m.forwardType = DynamicForward
default:
m.forwardType = LocalForward
}
m.inputs[pfTypeInput].SetValue(m.forwardType.String())
// Set values from saved configuration
if config.LocalPort != "" {
m.inputs[pfLocalPortInput].SetValue(config.LocalPort)
}
if config.RemoteHost != "" {
m.inputs[pfRemoteHostInput].SetValue(config.RemoteHost)
} else if m.forwardType != DynamicForward {
// Default to localhost for local and remote forwarding if not set
m.inputs[pfRemoteHostInput].SetValue("localhost")
}
if config.RemotePort != "" {
m.inputs[pfRemotePortInput].SetValue(config.RemotePort)
}
if config.BindAddress != "" {
m.inputs[pfBindAddressInput].SetValue(config.BindAddress)
}
}

View File

@ -672,7 +672,7 @@ func (m Model) handleListViewKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
selected := m.table.SelectedRow() selected := m.table.SelectedRow()
if len(selected) > 0 { if len(selected) > 0 {
hostName := extractHostNameFromTableRow(selected[0]) // Extract hostname from first column hostName := extractHostNameFromTableRow(selected[0]) // Extract hostname from first column
m.portForwardForm = NewPortForwardForm(hostName, m.styles, m.width, m.height, m.configFile) m.portForwardForm = NewPortForwardForm(hostName, m.styles, m.width, m.height, m.configFile, m.historyManager)
m.viewMode = ViewPortForward m.viewMode = ViewPortForward
return m, textinput.Blink return m, textinput.Blink
} }