feat: add port forwarding history persistence

This commit is contained in:
Gu1llaum-3 2025-09-12 11:29:17 +02:00
parent 71bf8ea2bb
commit 3c627a5d21
4 changed files with 334 additions and 37 deletions

View File

@ -15,11 +15,21 @@ type ConnectionHistory struct {
Connections map[string]ConnectionInfo `json:"connections"`
}
// PortForwardConfig stores port forwarding configuration
type PortForwardConfig struct {
Type string `json:"type"` // "local", "remote", "dynamic"
LocalPort string `json:"local_port"`
RemoteHost string `json:"remote_host"`
RemotePort string `json:"remote_port"`
BindAddress string `json:"bind_address"`
}
// ConnectionInfo stores information about a specific connection
type ConnectionInfo struct {
HostName string `json:"host_name"`
LastConnect time.Time `json:"last_connect"`
ConnectCount int `json:"connect_count"`
HostName string `json:"host_name"`
LastConnect time.Time `json:"last_connect"`
ConnectCount int `json:"connect_count"`
PortForwarding *PortForwardConfig `json:"port_forwarding,omitempty"`
}
// HistoryManager manages the connection history
@ -257,3 +267,42 @@ func (hm *HistoryManager) GetAllConnectionsInfo() []ConnectionInfo {
return connections
}
// RecordPortForwarding saves port forwarding configuration for a host
func (hm *HistoryManager) RecordPortForwarding(hostName, forwardType, localPort, remoteHost, remotePort, bindAddress string) error {
now := time.Now()
portForwardConfig := &PortForwardConfig{
Type: forwardType,
LocalPort: localPort,
RemoteHost: remoteHost,
RemotePort: remotePort,
BindAddress: bindAddress,
}
if conn, exists := hm.history.Connections[hostName]; exists {
// Update existing connection
conn.LastConnect = now
conn.ConnectCount++
conn.PortForwarding = portForwardConfig
hm.history.Connections[hostName] = conn
} else {
// Create new connection record
hm.history.Connections[hostName] = ConnectionInfo{
HostName: hostName,
LastConnect: now,
ConnectCount: 1,
PortForwarding: portForwardConfig,
}
}
return hm.saveHistory()
}
// GetPortForwardingConfig retrieves the last used port forwarding configuration for a host
func (hm *HistoryManager) GetPortForwardingConfig(hostName string) *PortForwardConfig {
if conn, exists := hm.history.Connections[hostName]; exists {
return conn.PortForwarding
}
return nil
}

View File

@ -0,0 +1,183 @@
package history
import (
"os"
"path/filepath"
"testing"
)
func TestPortForwardingHistory(t *testing.T) {
// Create temporary directory for testing
tempDir, err := os.MkdirTemp("", "sshm_test_*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tempDir)
// Create history manager with temp directory
historyPath := filepath.Join(tempDir, "test_history.json")
hm := &HistoryManager{
historyPath: historyPath,
history: &ConnectionHistory{Connections: make(map[string]ConnectionInfo)},
}
hostName := "test-server"
// Test recording port forwarding configuration
err = hm.RecordPortForwarding(hostName, "local", "8080", "localhost", "80", "127.0.0.1")
if err != nil {
t.Fatalf("Failed to record port forwarding: %v", err)
}
// Test retrieving port forwarding configuration
config := hm.GetPortForwardingConfig(hostName)
if config == nil {
t.Fatalf("Expected port forwarding config to exist")
}
// Verify the saved configuration
if config.Type != "local" {
t.Errorf("Expected Type 'local', got %s", config.Type)
}
if config.LocalPort != "8080" {
t.Errorf("Expected LocalPort '8080', got %s", config.LocalPort)
}
if config.RemoteHost != "localhost" {
t.Errorf("Expected RemoteHost 'localhost', got %s", config.RemoteHost)
}
if config.RemotePort != "80" {
t.Errorf("Expected RemotePort '80', got %s", config.RemotePort)
}
if config.BindAddress != "127.0.0.1" {
t.Errorf("Expected BindAddress '127.0.0.1', got %s", config.BindAddress)
}
// Test updating configuration with different values
err = hm.RecordPortForwarding(hostName, "remote", "3000", "app-server", "8000", "")
if err != nil {
t.Fatalf("Failed to record updated port forwarding: %v", err)
}
// Verify the updated configuration
config = hm.GetPortForwardingConfig(hostName)
if config == nil {
t.Fatalf("Expected port forwarding config to exist after update")
}
if config.Type != "remote" {
t.Errorf("Expected updated Type 'remote', got %s", config.Type)
}
if config.LocalPort != "3000" {
t.Errorf("Expected updated LocalPort '3000', got %s", config.LocalPort)
}
if config.RemoteHost != "app-server" {
t.Errorf("Expected updated RemoteHost 'app-server', got %s", config.RemoteHost)
}
if config.RemotePort != "8000" {
t.Errorf("Expected updated RemotePort '8000', got %s", config.RemotePort)
}
if config.BindAddress != "" {
t.Errorf("Expected updated BindAddress to be empty, got %s", config.BindAddress)
}
// Test dynamic forwarding
err = hm.RecordPortForwarding(hostName, "dynamic", "1080", "", "", "0.0.0.0")
if err != nil {
t.Fatalf("Failed to record dynamic port forwarding: %v", err)
}
config = hm.GetPortForwardingConfig(hostName)
if config == nil {
t.Fatalf("Expected port forwarding config to exist for dynamic forwarding")
}
if config.Type != "dynamic" {
t.Errorf("Expected Type 'dynamic', got %s", config.Type)
}
if config.LocalPort != "1080" {
t.Errorf("Expected LocalPort '1080', got %s", config.LocalPort)
}
if config.RemoteHost != "" {
t.Errorf("Expected RemoteHost to be empty for dynamic forwarding, got %s", config.RemoteHost)
}
if config.RemotePort != "" {
t.Errorf("Expected RemotePort to be empty for dynamic forwarding, got %s", config.RemotePort)
}
if config.BindAddress != "0.0.0.0" {
t.Errorf("Expected BindAddress '0.0.0.0', got %s", config.BindAddress)
}
}
func TestPortForwardingHistoryPersistence(t *testing.T) {
// Create temporary directory for testing
tempDir, err := os.MkdirTemp("", "sshm_test_*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tempDir)
historyPath := filepath.Join(tempDir, "test_history.json")
// Create first history manager and record data
hm1 := &HistoryManager{
historyPath: historyPath,
history: &ConnectionHistory{Connections: make(map[string]ConnectionInfo)},
}
hostName := "persistent-server"
err = hm1.RecordPortForwarding(hostName, "local", "9090", "db-server", "5432", "")
if err != nil {
t.Fatalf("Failed to record port forwarding: %v", err)
}
// Create second history manager and load data
hm2 := &HistoryManager{
historyPath: historyPath,
history: &ConnectionHistory{Connections: make(map[string]ConnectionInfo)},
}
err = hm2.loadHistory()
if err != nil {
t.Fatalf("Failed to load history: %v", err)
}
// Verify the loaded configuration
config := hm2.GetPortForwardingConfig(hostName)
if config == nil {
t.Fatalf("Expected port forwarding config to be loaded from file")
}
if config.Type != "local" {
t.Errorf("Expected loaded Type 'local', got %s", config.Type)
}
if config.LocalPort != "9090" {
t.Errorf("Expected loaded LocalPort '9090', got %s", config.LocalPort)
}
if config.RemoteHost != "db-server" {
t.Errorf("Expected loaded RemoteHost 'db-server', got %s", config.RemoteHost)
}
if config.RemotePort != "5432" {
t.Errorf("Expected loaded RemotePort '5432', got %s", config.RemotePort)
}
}
func TestGetPortForwardingConfigNonExistent(t *testing.T) {
// Create temporary directory for testing
tempDir, err := os.MkdirTemp("", "sshm_test_*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tempDir)
historyPath := filepath.Join(tempDir, "test_history.json")
hm := &HistoryManager{
historyPath: historyPath,
history: &ConnectionHistory{Connections: make(map[string]ConnectionInfo)},
}
// Test getting configuration for non-existent host
config := hm.GetPortForwardingConfig("non-existent-host")
if config != nil {
t.Errorf("Expected nil config for non-existent host, got %+v", config)
}
}

View File

@ -5,6 +5,7 @@ import (
"strconv"
"strings"
"github.com/Gu1llaum-3/sshm/internal/history"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
@ -20,15 +21,16 @@ const (
)
type portForwardModel struct {
inputs []textinput.Model
focused int
forwardType PortForwardType
hostName string
err string
styles Styles
width int
height int
configFile string
inputs []textinput.Model
focused int
forwardType PortForwardType
hostName string
err string
styles Styles
width int
height int
configFile string
historyManager *history.HistoryManager
}
// portForwardSubmitMsg is sent when the port forward form is submitted
@ -41,7 +43,7 @@ type portForwardSubmitMsg struct {
type portForwardCancelMsg struct{}
// NewPortForwardForm creates a new port forward form model
func NewPortForwardForm(hostName string, styles Styles, width, height int, configFile string) *portForwardModel {
func NewPortForwardForm(hostName string, styles Styles, width, height int, configFile string, historyManager *history.HistoryManager) *portForwardModel {
inputs := make([]textinput.Model, 5)
// Forward type input (display only, controlled by arrow keys)
@ -49,7 +51,6 @@ func NewPortForwardForm(hostName string, styles Styles, width, height int, confi
inputs[pfTypeInput].Placeholder = "Use ←/→ to change forward type"
inputs[pfTypeInput].Focus()
inputs[pfTypeInput].Width = 40
inputs[pfTypeInput].SetValue("Local (-L)")
// Local port input
inputs[pfLocalPortInput] = textinput.New()
@ -77,16 +78,20 @@ func NewPortForwardForm(hostName string, styles Styles, width, height int, confi
inputs[pfBindAddressInput].Width = 30
pf := &portForwardModel{
inputs: inputs,
focused: 0,
forwardType: LocalForward,
hostName: hostName,
styles: styles,
width: width,
height: height,
configFile: configFile,
inputs: inputs,
focused: 0,
forwardType: LocalForward,
hostName: hostName,
styles: styles,
width: width,
height: height,
configFile: configFile,
historyManager: historyManager,
}
// Load previous port forwarding configuration if available
pf.loadPreviousConfig()
// Initialize input visibility
pf.updateInputVisibility()
@ -370,6 +375,11 @@ func (m *portForwardModel) submitForm() tea.Cmd {
return portForwardSubmitMsg{err: fmt.Errorf("invalid port number"), sshArgs: nil}
}
// Get form values for saving to history
remoteHost := strings.TrimSpace(m.inputs[pfRemoteHostInput].Value())
remotePort := strings.TrimSpace(m.inputs[pfRemotePortInput].Value())
bindAddress := strings.TrimSpace(m.inputs[pfBindAddressInput].Value())
// Build SSH command with port forwarding
var sshArgs []string
@ -379,13 +389,10 @@ func (m *portForwardModel) submitForm() tea.Cmd {
}
// Add forwarding arguments
bindAddress := strings.TrimSpace(m.inputs[pfBindAddressInput].Value())
var forwardTypeStr string
switch m.forwardType {
case LocalForward:
remoteHost := strings.TrimSpace(m.inputs[pfRemoteHostInput].Value())
remotePort := strings.TrimSpace(m.inputs[pfRemotePortInput].Value())
forwardTypeStr = "local"
if remoteHost == "" {
remoteHost = "localhost"
}
@ -408,31 +415,30 @@ func (m *portForwardModel) submitForm() tea.Cmd {
sshArgs = append(sshArgs, "-L", forwardArg)
case RemoteForward:
localHost := strings.TrimSpace(m.inputs[pfRemoteHostInput].Value())
localPortStr := strings.TrimSpace(m.inputs[pfRemotePortInput].Value())
if localHost == "" {
localHost = "localhost"
forwardTypeStr = "remote"
if remoteHost == "" {
remoteHost = "localhost"
}
if localPortStr == "" {
if remotePort == "" {
return portForwardSubmitMsg{err: fmt.Errorf("local port is required for remote forwarding"), sshArgs: nil}
}
// Validate local port
if _, err := strconv.Atoi(localPortStr); err != nil {
if _, err := strconv.Atoi(remotePort); err != nil {
return portForwardSubmitMsg{err: fmt.Errorf("invalid local port number"), sshArgs: nil}
}
// Build -R argument (note: localPort is actually the remote port in this context)
var forwardArg string
if bindAddress != "" {
forwardArg = fmt.Sprintf("%s:%s:%s:%s", bindAddress, localPort, localHost, localPortStr)
forwardArg = fmt.Sprintf("%s:%s:%s:%s", bindAddress, localPort, remoteHost, remotePort)
} else {
forwardArg = fmt.Sprintf("%s:%s:%s", localPort, localHost, localPortStr)
forwardArg = fmt.Sprintf("%s:%s:%s", localPort, remoteHost, remotePort)
}
sshArgs = append(sshArgs, "-R", forwardArg)
case DynamicForward:
forwardTypeStr = "dynamic"
// Build -D argument
var forwardArg string
if bindAddress != "" {
@ -443,6 +449,21 @@ func (m *portForwardModel) submitForm() tea.Cmd {
sshArgs = append(sshArgs, "-D", forwardArg)
}
// Save port forwarding configuration to history
if m.historyManager != nil {
if err := m.historyManager.RecordPortForwarding(
m.hostName,
forwardTypeStr,
localPort,
remoteHost,
remotePort,
bindAddress,
); err != nil {
// Log the error but don't fail the connection
// In a production environment, you might want to handle this differently
}
}
// Add hostname
sshArgs = append(sshArgs, m.hostName)
@ -488,3 +509,47 @@ func (m *portForwardModel) getPrevValidField(currentField int) int {
}
return -1
}
// loadPreviousConfig loads the previous port forwarding configuration for this host
func (m *portForwardModel) loadPreviousConfig() {
if m.historyManager == nil {
m.inputs[pfTypeInput].SetValue("Local (-L)")
return
}
config := m.historyManager.GetPortForwardingConfig(m.hostName)
if config == nil {
m.inputs[pfTypeInput].SetValue("Local (-L)")
return
}
// Set forward type based on saved configuration
switch config.Type {
case "local":
m.forwardType = LocalForward
case "remote":
m.forwardType = RemoteForward
case "dynamic":
m.forwardType = DynamicForward
default:
m.forwardType = LocalForward
}
m.inputs[pfTypeInput].SetValue(m.forwardType.String())
// Set values from saved configuration
if config.LocalPort != "" {
m.inputs[pfLocalPortInput].SetValue(config.LocalPort)
}
if config.RemoteHost != "" {
m.inputs[pfRemoteHostInput].SetValue(config.RemoteHost)
} else if m.forwardType != DynamicForward {
// Default to localhost for local and remote forwarding if not set
m.inputs[pfRemoteHostInput].SetValue("localhost")
}
if config.RemotePort != "" {
m.inputs[pfRemotePortInput].SetValue(config.RemotePort)
}
if config.BindAddress != "" {
m.inputs[pfBindAddressInput].SetValue(config.BindAddress)
}
}

View File

@ -672,7 +672,7 @@ func (m Model) handleListViewKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
selected := m.table.SelectedRow()
if len(selected) > 0 {
hostName := extractHostNameFromTableRow(selected[0]) // Extract hostname from first column
m.portForwardForm = NewPortForwardForm(hostName, m.styles, m.width, m.height, m.configFile)
m.portForwardForm = NewPortForwardForm(hostName, m.styles, m.width, m.height, m.configFile, m.historyManager)
m.viewMode = ViewPortForward
return m, textinput.Blink
}