mirror of
https://github.com/Gu1llaum-3/sshm.git
synced 2025-10-19 01:17:20 +02:00
feat: add port forwarding history persistence
This commit is contained in:
parent
71bf8ea2bb
commit
3c627a5d21
@ -15,11 +15,21 @@ type ConnectionHistory struct {
|
||||
Connections map[string]ConnectionInfo `json:"connections"`
|
||||
}
|
||||
|
||||
// PortForwardConfig stores port forwarding configuration
|
||||
type PortForwardConfig struct {
|
||||
Type string `json:"type"` // "local", "remote", "dynamic"
|
||||
LocalPort string `json:"local_port"`
|
||||
RemoteHost string `json:"remote_host"`
|
||||
RemotePort string `json:"remote_port"`
|
||||
BindAddress string `json:"bind_address"`
|
||||
}
|
||||
|
||||
// ConnectionInfo stores information about a specific connection
|
||||
type ConnectionInfo struct {
|
||||
HostName string `json:"host_name"`
|
||||
LastConnect time.Time `json:"last_connect"`
|
||||
ConnectCount int `json:"connect_count"`
|
||||
HostName string `json:"host_name"`
|
||||
LastConnect time.Time `json:"last_connect"`
|
||||
ConnectCount int `json:"connect_count"`
|
||||
PortForwarding *PortForwardConfig `json:"port_forwarding,omitempty"`
|
||||
}
|
||||
|
||||
// HistoryManager manages the connection history
|
||||
@ -257,3 +267,42 @@ func (hm *HistoryManager) GetAllConnectionsInfo() []ConnectionInfo {
|
||||
|
||||
return connections
|
||||
}
|
||||
|
||||
// RecordPortForwarding saves port forwarding configuration for a host
|
||||
func (hm *HistoryManager) RecordPortForwarding(hostName, forwardType, localPort, remoteHost, remotePort, bindAddress string) error {
|
||||
now := time.Now()
|
||||
|
||||
portForwardConfig := &PortForwardConfig{
|
||||
Type: forwardType,
|
||||
LocalPort: localPort,
|
||||
RemoteHost: remoteHost,
|
||||
RemotePort: remotePort,
|
||||
BindAddress: bindAddress,
|
||||
}
|
||||
|
||||
if conn, exists := hm.history.Connections[hostName]; exists {
|
||||
// Update existing connection
|
||||
conn.LastConnect = now
|
||||
conn.ConnectCount++
|
||||
conn.PortForwarding = portForwardConfig
|
||||
hm.history.Connections[hostName] = conn
|
||||
} else {
|
||||
// Create new connection record
|
||||
hm.history.Connections[hostName] = ConnectionInfo{
|
||||
HostName: hostName,
|
||||
LastConnect: now,
|
||||
ConnectCount: 1,
|
||||
PortForwarding: portForwardConfig,
|
||||
}
|
||||
}
|
||||
|
||||
return hm.saveHistory()
|
||||
}
|
||||
|
||||
// GetPortForwardingConfig retrieves the last used port forwarding configuration for a host
|
||||
func (hm *HistoryManager) GetPortForwardingConfig(hostName string) *PortForwardConfig {
|
||||
if conn, exists := hm.history.Connections[hostName]; exists {
|
||||
return conn.PortForwarding
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
183
internal/history/port_forward_test.go
Normal file
183
internal/history/port_forward_test.go
Normal file
@ -0,0 +1,183 @@
|
||||
package history
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPortForwardingHistory(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "sshm_test_*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create history manager with temp directory
|
||||
historyPath := filepath.Join(tempDir, "test_history.json")
|
||||
hm := &HistoryManager{
|
||||
historyPath: historyPath,
|
||||
history: &ConnectionHistory{Connections: make(map[string]ConnectionInfo)},
|
||||
}
|
||||
|
||||
hostName := "test-server"
|
||||
|
||||
// Test recording port forwarding configuration
|
||||
err = hm.RecordPortForwarding(hostName, "local", "8080", "localhost", "80", "127.0.0.1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to record port forwarding: %v", err)
|
||||
}
|
||||
|
||||
// Test retrieving port forwarding configuration
|
||||
config := hm.GetPortForwardingConfig(hostName)
|
||||
if config == nil {
|
||||
t.Fatalf("Expected port forwarding config to exist")
|
||||
}
|
||||
|
||||
// Verify the saved configuration
|
||||
if config.Type != "local" {
|
||||
t.Errorf("Expected Type 'local', got %s", config.Type)
|
||||
}
|
||||
if config.LocalPort != "8080" {
|
||||
t.Errorf("Expected LocalPort '8080', got %s", config.LocalPort)
|
||||
}
|
||||
if config.RemoteHost != "localhost" {
|
||||
t.Errorf("Expected RemoteHost 'localhost', got %s", config.RemoteHost)
|
||||
}
|
||||
if config.RemotePort != "80" {
|
||||
t.Errorf("Expected RemotePort '80', got %s", config.RemotePort)
|
||||
}
|
||||
if config.BindAddress != "127.0.0.1" {
|
||||
t.Errorf("Expected BindAddress '127.0.0.1', got %s", config.BindAddress)
|
||||
}
|
||||
|
||||
// Test updating configuration with different values
|
||||
err = hm.RecordPortForwarding(hostName, "remote", "3000", "app-server", "8000", "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to record updated port forwarding: %v", err)
|
||||
}
|
||||
|
||||
// Verify the updated configuration
|
||||
config = hm.GetPortForwardingConfig(hostName)
|
||||
if config == nil {
|
||||
t.Fatalf("Expected port forwarding config to exist after update")
|
||||
}
|
||||
|
||||
if config.Type != "remote" {
|
||||
t.Errorf("Expected updated Type 'remote', got %s", config.Type)
|
||||
}
|
||||
if config.LocalPort != "3000" {
|
||||
t.Errorf("Expected updated LocalPort '3000', got %s", config.LocalPort)
|
||||
}
|
||||
if config.RemoteHost != "app-server" {
|
||||
t.Errorf("Expected updated RemoteHost 'app-server', got %s", config.RemoteHost)
|
||||
}
|
||||
if config.RemotePort != "8000" {
|
||||
t.Errorf("Expected updated RemotePort '8000', got %s", config.RemotePort)
|
||||
}
|
||||
if config.BindAddress != "" {
|
||||
t.Errorf("Expected updated BindAddress to be empty, got %s", config.BindAddress)
|
||||
}
|
||||
|
||||
// Test dynamic forwarding
|
||||
err = hm.RecordPortForwarding(hostName, "dynamic", "1080", "", "", "0.0.0.0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to record dynamic port forwarding: %v", err)
|
||||
}
|
||||
|
||||
config = hm.GetPortForwardingConfig(hostName)
|
||||
if config == nil {
|
||||
t.Fatalf("Expected port forwarding config to exist for dynamic forwarding")
|
||||
}
|
||||
|
||||
if config.Type != "dynamic" {
|
||||
t.Errorf("Expected Type 'dynamic', got %s", config.Type)
|
||||
}
|
||||
if config.LocalPort != "1080" {
|
||||
t.Errorf("Expected LocalPort '1080', got %s", config.LocalPort)
|
||||
}
|
||||
if config.RemoteHost != "" {
|
||||
t.Errorf("Expected RemoteHost to be empty for dynamic forwarding, got %s", config.RemoteHost)
|
||||
}
|
||||
if config.RemotePort != "" {
|
||||
t.Errorf("Expected RemotePort to be empty for dynamic forwarding, got %s", config.RemotePort)
|
||||
}
|
||||
if config.BindAddress != "0.0.0.0" {
|
||||
t.Errorf("Expected BindAddress '0.0.0.0', got %s", config.BindAddress)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortForwardingHistoryPersistence(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "sshm_test_*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
historyPath := filepath.Join(tempDir, "test_history.json")
|
||||
|
||||
// Create first history manager and record data
|
||||
hm1 := &HistoryManager{
|
||||
historyPath: historyPath,
|
||||
history: &ConnectionHistory{Connections: make(map[string]ConnectionInfo)},
|
||||
}
|
||||
|
||||
hostName := "persistent-server"
|
||||
err = hm1.RecordPortForwarding(hostName, "local", "9090", "db-server", "5432", "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to record port forwarding: %v", err)
|
||||
}
|
||||
|
||||
// Create second history manager and load data
|
||||
hm2 := &HistoryManager{
|
||||
historyPath: historyPath,
|
||||
history: &ConnectionHistory{Connections: make(map[string]ConnectionInfo)},
|
||||
}
|
||||
|
||||
err = hm2.loadHistory()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load history: %v", err)
|
||||
}
|
||||
|
||||
// Verify the loaded configuration
|
||||
config := hm2.GetPortForwardingConfig(hostName)
|
||||
if config == nil {
|
||||
t.Fatalf("Expected port forwarding config to be loaded from file")
|
||||
}
|
||||
|
||||
if config.Type != "local" {
|
||||
t.Errorf("Expected loaded Type 'local', got %s", config.Type)
|
||||
}
|
||||
if config.LocalPort != "9090" {
|
||||
t.Errorf("Expected loaded LocalPort '9090', got %s", config.LocalPort)
|
||||
}
|
||||
if config.RemoteHost != "db-server" {
|
||||
t.Errorf("Expected loaded RemoteHost 'db-server', got %s", config.RemoteHost)
|
||||
}
|
||||
if config.RemotePort != "5432" {
|
||||
t.Errorf("Expected loaded RemotePort '5432', got %s", config.RemotePort)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPortForwardingConfigNonExistent(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "sshm_test_*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
historyPath := filepath.Join(tempDir, "test_history.json")
|
||||
hm := &HistoryManager{
|
||||
historyPath: historyPath,
|
||||
history: &ConnectionHistory{Connections: make(map[string]ConnectionInfo)},
|
||||
}
|
||||
|
||||
// Test getting configuration for non-existent host
|
||||
config := hm.GetPortForwardingConfig("non-existent-host")
|
||||
if config != nil {
|
||||
t.Errorf("Expected nil config for non-existent host, got %+v", config)
|
||||
}
|
||||
}
|
@ -5,6 +5,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Gu1llaum-3/sshm/internal/history"
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
@ -20,15 +21,16 @@ const (
|
||||
)
|
||||
|
||||
type portForwardModel struct {
|
||||
inputs []textinput.Model
|
||||
focused int
|
||||
forwardType PortForwardType
|
||||
hostName string
|
||||
err string
|
||||
styles Styles
|
||||
width int
|
||||
height int
|
||||
configFile string
|
||||
inputs []textinput.Model
|
||||
focused int
|
||||
forwardType PortForwardType
|
||||
hostName string
|
||||
err string
|
||||
styles Styles
|
||||
width int
|
||||
height int
|
||||
configFile string
|
||||
historyManager *history.HistoryManager
|
||||
}
|
||||
|
||||
// portForwardSubmitMsg is sent when the port forward form is submitted
|
||||
@ -41,7 +43,7 @@ type portForwardSubmitMsg struct {
|
||||
type portForwardCancelMsg struct{}
|
||||
|
||||
// NewPortForwardForm creates a new port forward form model
|
||||
func NewPortForwardForm(hostName string, styles Styles, width, height int, configFile string) *portForwardModel {
|
||||
func NewPortForwardForm(hostName string, styles Styles, width, height int, configFile string, historyManager *history.HistoryManager) *portForwardModel {
|
||||
inputs := make([]textinput.Model, 5)
|
||||
|
||||
// Forward type input (display only, controlled by arrow keys)
|
||||
@ -49,7 +51,6 @@ func NewPortForwardForm(hostName string, styles Styles, width, height int, confi
|
||||
inputs[pfTypeInput].Placeholder = "Use ←/→ to change forward type"
|
||||
inputs[pfTypeInput].Focus()
|
||||
inputs[pfTypeInput].Width = 40
|
||||
inputs[pfTypeInput].SetValue("Local (-L)")
|
||||
|
||||
// Local port input
|
||||
inputs[pfLocalPortInput] = textinput.New()
|
||||
@ -77,16 +78,20 @@ func NewPortForwardForm(hostName string, styles Styles, width, height int, confi
|
||||
inputs[pfBindAddressInput].Width = 30
|
||||
|
||||
pf := &portForwardModel{
|
||||
inputs: inputs,
|
||||
focused: 0,
|
||||
forwardType: LocalForward,
|
||||
hostName: hostName,
|
||||
styles: styles,
|
||||
width: width,
|
||||
height: height,
|
||||
configFile: configFile,
|
||||
inputs: inputs,
|
||||
focused: 0,
|
||||
forwardType: LocalForward,
|
||||
hostName: hostName,
|
||||
styles: styles,
|
||||
width: width,
|
||||
height: height,
|
||||
configFile: configFile,
|
||||
historyManager: historyManager,
|
||||
}
|
||||
|
||||
// Load previous port forwarding configuration if available
|
||||
pf.loadPreviousConfig()
|
||||
|
||||
// Initialize input visibility
|
||||
pf.updateInputVisibility()
|
||||
|
||||
@ -370,6 +375,11 @@ func (m *portForwardModel) submitForm() tea.Cmd {
|
||||
return portForwardSubmitMsg{err: fmt.Errorf("invalid port number"), sshArgs: nil}
|
||||
}
|
||||
|
||||
// Get form values for saving to history
|
||||
remoteHost := strings.TrimSpace(m.inputs[pfRemoteHostInput].Value())
|
||||
remotePort := strings.TrimSpace(m.inputs[pfRemotePortInput].Value())
|
||||
bindAddress := strings.TrimSpace(m.inputs[pfBindAddressInput].Value())
|
||||
|
||||
// Build SSH command with port forwarding
|
||||
var sshArgs []string
|
||||
|
||||
@ -379,13 +389,10 @@ func (m *portForwardModel) submitForm() tea.Cmd {
|
||||
}
|
||||
|
||||
// Add forwarding arguments
|
||||
bindAddress := strings.TrimSpace(m.inputs[pfBindAddressInput].Value())
|
||||
|
||||
var forwardTypeStr string
|
||||
switch m.forwardType {
|
||||
case LocalForward:
|
||||
remoteHost := strings.TrimSpace(m.inputs[pfRemoteHostInput].Value())
|
||||
remotePort := strings.TrimSpace(m.inputs[pfRemotePortInput].Value())
|
||||
|
||||
forwardTypeStr = "local"
|
||||
if remoteHost == "" {
|
||||
remoteHost = "localhost"
|
||||
}
|
||||
@ -408,31 +415,30 @@ func (m *portForwardModel) submitForm() tea.Cmd {
|
||||
sshArgs = append(sshArgs, "-L", forwardArg)
|
||||
|
||||
case RemoteForward:
|
||||
localHost := strings.TrimSpace(m.inputs[pfRemoteHostInput].Value())
|
||||
localPortStr := strings.TrimSpace(m.inputs[pfRemotePortInput].Value())
|
||||
|
||||
if localHost == "" {
|
||||
localHost = "localhost"
|
||||
forwardTypeStr = "remote"
|
||||
if remoteHost == "" {
|
||||
remoteHost = "localhost"
|
||||
}
|
||||
if localPortStr == "" {
|
||||
if remotePort == "" {
|
||||
return portForwardSubmitMsg{err: fmt.Errorf("local port is required for remote forwarding"), sshArgs: nil}
|
||||
}
|
||||
|
||||
// Validate local port
|
||||
if _, err := strconv.Atoi(localPortStr); err != nil {
|
||||
if _, err := strconv.Atoi(remotePort); err != nil {
|
||||
return portForwardSubmitMsg{err: fmt.Errorf("invalid local port number"), sshArgs: nil}
|
||||
}
|
||||
|
||||
// Build -R argument (note: localPort is actually the remote port in this context)
|
||||
var forwardArg string
|
||||
if bindAddress != "" {
|
||||
forwardArg = fmt.Sprintf("%s:%s:%s:%s", bindAddress, localPort, localHost, localPortStr)
|
||||
forwardArg = fmt.Sprintf("%s:%s:%s:%s", bindAddress, localPort, remoteHost, remotePort)
|
||||
} else {
|
||||
forwardArg = fmt.Sprintf("%s:%s:%s", localPort, localHost, localPortStr)
|
||||
forwardArg = fmt.Sprintf("%s:%s:%s", localPort, remoteHost, remotePort)
|
||||
}
|
||||
sshArgs = append(sshArgs, "-R", forwardArg)
|
||||
|
||||
case DynamicForward:
|
||||
forwardTypeStr = "dynamic"
|
||||
// Build -D argument
|
||||
var forwardArg string
|
||||
if bindAddress != "" {
|
||||
@ -443,6 +449,21 @@ func (m *portForwardModel) submitForm() tea.Cmd {
|
||||
sshArgs = append(sshArgs, "-D", forwardArg)
|
||||
}
|
||||
|
||||
// Save port forwarding configuration to history
|
||||
if m.historyManager != nil {
|
||||
if err := m.historyManager.RecordPortForwarding(
|
||||
m.hostName,
|
||||
forwardTypeStr,
|
||||
localPort,
|
||||
remoteHost,
|
||||
remotePort,
|
||||
bindAddress,
|
||||
); err != nil {
|
||||
// Log the error but don't fail the connection
|
||||
// In a production environment, you might want to handle this differently
|
||||
}
|
||||
}
|
||||
|
||||
// Add hostname
|
||||
sshArgs = append(sshArgs, m.hostName)
|
||||
|
||||
@ -488,3 +509,47 @@ func (m *portForwardModel) getPrevValidField(currentField int) int {
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// loadPreviousConfig loads the previous port forwarding configuration for this host
|
||||
func (m *portForwardModel) loadPreviousConfig() {
|
||||
if m.historyManager == nil {
|
||||
m.inputs[pfTypeInput].SetValue("Local (-L)")
|
||||
return
|
||||
}
|
||||
|
||||
config := m.historyManager.GetPortForwardingConfig(m.hostName)
|
||||
if config == nil {
|
||||
m.inputs[pfTypeInput].SetValue("Local (-L)")
|
||||
return
|
||||
}
|
||||
|
||||
// Set forward type based on saved configuration
|
||||
switch config.Type {
|
||||
case "local":
|
||||
m.forwardType = LocalForward
|
||||
case "remote":
|
||||
m.forwardType = RemoteForward
|
||||
case "dynamic":
|
||||
m.forwardType = DynamicForward
|
||||
default:
|
||||
m.forwardType = LocalForward
|
||||
}
|
||||
m.inputs[pfTypeInput].SetValue(m.forwardType.String())
|
||||
|
||||
// Set values from saved configuration
|
||||
if config.LocalPort != "" {
|
||||
m.inputs[pfLocalPortInput].SetValue(config.LocalPort)
|
||||
}
|
||||
if config.RemoteHost != "" {
|
||||
m.inputs[pfRemoteHostInput].SetValue(config.RemoteHost)
|
||||
} else if m.forwardType != DynamicForward {
|
||||
// Default to localhost for local and remote forwarding if not set
|
||||
m.inputs[pfRemoteHostInput].SetValue("localhost")
|
||||
}
|
||||
if config.RemotePort != "" {
|
||||
m.inputs[pfRemotePortInput].SetValue(config.RemotePort)
|
||||
}
|
||||
if config.BindAddress != "" {
|
||||
m.inputs[pfBindAddressInput].SetValue(config.BindAddress)
|
||||
}
|
||||
}
|
||||
|
@ -672,7 +672,7 @@ func (m Model) handleListViewKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
selected := m.table.SelectedRow()
|
||||
if len(selected) > 0 {
|
||||
hostName := extractHostNameFromTableRow(selected[0]) // Extract hostname from first column
|
||||
m.portForwardForm = NewPortForwardForm(hostName, m.styles, m.width, m.height, m.configFile)
|
||||
m.portForwardForm = NewPortForwardForm(hostName, m.styles, m.width, m.height, m.configFile, m.historyManager)
|
||||
m.viewMode = ViewPortForward
|
||||
return m, textinput.Blink
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user