将下面的代码保存为 ssh2http.go,然后执行如下命令:
go mod init example.com/m
go mod tidy
GOOS=linux GOARCH=amd64 go build ssh2http.go
# 可选的目标平台(完整列表可以通过 go tool dist list 获取)
PLATFORMS=(
"linux/amd64"
"linux/arm64"
"darwin/amd64"
"darwin/arm64"
"windows/amd64"
)
# 运行
ssh2http --ssh-host=1.2.3.4:22 \
--ssh-key-passphrase=123456 \
--ssh-key=/Users/changhui.wy/.ssh/id_ed25519 \
--local=:38080
ssh2http --ssh-host=1.2.2.3:22 \
--ssh-user=root \
--ssh-password=123456 \
--local=:38080
源码如下:
package main
import (
"bufio"
"encoding/base64"
"encoding/binary"
"flag"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"golang.org/x/crypto/ssh"
)
// Command-line flags. Each variable points at the value filled in by
// flag.Parse, which main calls before reading any of them.
var (
// Credentials for the HTTP proxy (username and password).
httpProxyUser = flag.String("http-proxy-user", "", "HTTP proxy username")
httpProxyPass = flag.String("http-proxy-pass", "", "HTTP proxy password")
sshHost = flag.String("ssh-host", "", "SSH server address (e.g. 1.2.3.4:22)")
sshUser = flag.String("ssh-user", "root", "SSH username")
sshPassword = flag.String("ssh-password", "", "SSH password (optional if using private key)")
sshKeyFile = flag.String("ssh-key", "", "Path to private key file (e.g. id_rsa)")
localAddr = flag.String("local", ":8080", "Local HTTP proxy listen address (e.g. :8080)")
socks5Addr = flag.String("socks5", ":1080", "Local SOCKS5 proxy listen address (e.g. :1080)")
socks5User = flag.String("socks5-user", "", "SOCKS5 proxy username")
socks5Pass = flag.String("socks5-pass", "", "SOCKS5 proxy password")
reconnectSec = flag.Int("reconnect-interval", 5, "Reconnect interval in seconds after failure")
sshKeyPassphrase = flag.String("ssh-key-passphrase", "", "Passphrase for encrypted private key (optional)")
sshConfigFile = flag.String("ssh-config", "", "Path to SSH config file (default: ~/.ssh/config)")
sshConfigHost = flag.String("ssh-config-host", "", "Host alias from SSH config file to use")
enableHTTP = flag.Bool("enable-http", true, "Enable HTTP proxy server")
enableSOCKS5 = flag.Bool("enable-socks5", true, "Enable SOCKS5 proxy server")
)
// SSHConfig holds the settings parsed from one Host entry of an ssh_config
// file.
type SSHConfig struct {
Host string
HostName string
Port string
User string
IdentityFile string
Password string // password authentication is supported
ProxyJump []string // multiple jump hosts (multi-hop) are supported
}
// main installs the usage text, parses and validates the command-line flags,
// builds the SSHManager shared by both proxies, starts the enabled HTTP and
// SOCKS5 servers in their own goroutines, and then waits for them.
func main() {
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s [options]\n", os.Args[0])
fmt.Fprintln(os.Stderr, "\nOptions:")
flag.PrintDefaults()
fmt.Fprintln(os.Stderr, `
Examples:
# Direct connection with encrypted private key (both HTTP and SOCKS5)
ssh2http --ssh-host=1.2.3.4:22 --ssh-key=id_rsa --ssh-key-passphrase='mysecret' --http-proxy-user=admin --http-proxy-pass=admin --socks5-user=admin --socks5-pass=admin --local=:8080 --socks5=:1080
# Using SSH config file with jump servers
ssh2http --ssh-config=~/.ssh/config --ssh-config-host=production-server --http-proxy-user=admin --http-proxy-pass=admin --local=:8080
# Only SOCKS5 proxy (disable HTTP)
ssh2http --ssh-host=1.2.3.4:22 --ssh-user=root --ssh-password=mypass --enable-http=false --socks5=:1080
# Use for shell (HTTP proxy)
export https_proxy=http://admin:admin@localhost:8080
export http_proxy=http://admin:admin@localhost:8080
# Use for shell (SOCKS5 proxy)
export all_proxy=socks5://admin:admin@localhost:1080
SSH Config Example (~/.ssh/config):
# 使用密钥认证
Host jump1
HostName 1.2.3.4
User jumpuser
Port 22
IdentityFile ~/.ssh/jump1_key
# 使用密码认证(注意:这是自定义扩展,非标准 OpenSSH 配置)
Host jump2
HostName 5.6.7.8
User jumpuser2
Port 22
Password mypassword123
ProxyJump jump1
# 混合认证:跳板用密码,目标用密钥
Host production-server
HostName 192.168.1.100
User root
Port 22
IdentityFile ~/.ssh/prod_key
ProxyJump jump1,jump2
Notes:
- Either --ssh-config-host or --ssh-host must be provided.
- If using --ssh-config-host, --ssh-config will default to ~/.ssh/config
- ProxyJump supports multiple hops: jump1,jump2,jump3
- ProxyJump uses SSH -W (direct-tcpip channel) for standard SSH tunneling
- SOCKS5 proxy supports authentication (username/password)
`)
}
// With no arguments at all, print the usage text and exit successfully.
if len(os.Args) == 1 {
flag.Usage()
os.Exit(0)
}
flag.Parse()
// When --ssh-config-host is given without --ssh-config, fall back to the
// default path ~/.ssh/config.
if *sshConfigHost != "" && *sshConfigFile == "" {
homeDir, err := os.UserHomeDir()
if err != nil {
log.Fatal("Cannot determine home directory:", err)
}
*sshConfigFile = filepath.Join(homeDir, ".ssh", "config")
}
if *sshHost == "" && *sshConfigHost == "" {
log.Fatal("Error: either --ssh-host or --ssh-config-host is required")
}
if !*enableHTTP && !*enableSOCKS5 {
log.Fatal("Error: at least one of --enable-http or --enable-socks5 must be true")
}
// Create the SSH connection manager shared by both proxy servers.
sshManager := &SSHManager{
sshHost: *sshHost,
sshUser: *sshUser,
sshPassword: *sshPassword,
sshKeyFile: parseShellPath(*sshKeyFile),
reconnectSec: time.Duration(*reconnectSec) * time.Second,
sshConfigFile: *sshConfigFile,
sshConfigHost: *sshConfigHost,
}
var wg sync.WaitGroup
// Start the HTTP proxy server.
if *enableHTTP {
wg.Add(1)
go func() {
defer wg.Done()
proxy := &HTTPProxy{sshManager: sshManager}
log.Printf("Starting HTTP proxy on %s", *localAddr)
if *sshConfigHost != "" {
log.Printf("Using SSH config host: %s from %s", *sshConfigHost, *sshConfigFile)
} else {
log.Printf("Direct connection to SSH server: %s", *sshHost)
}
err := http.ListenAndServe(*localAddr, proxy)
if err != nil {
log.Fatalf("Failed to start HTTP proxy: %v", err)
}
}()
}
// Start the SOCKS5 proxy server.
if *enableSOCKS5 {
wg.Add(1)
go func() {
defer wg.Done()
socks5Server := &SOCKS5Server{
sshManager: sshManager,
username: *socks5User,
password: *socks5Pass,
}
log.Printf("Starting SOCKS5 proxy on %s", *socks5Addr)
err := socks5Server.ListenAndServe(*socks5Addr)
if err != nil {
log.Fatalf("Failed to start SOCKS5 proxy: %v", err)
}
}()
}
wg.Wait()
}
// SSHManager manages the SSH connection that is shared by the HTTP and
// SOCKS5 proxies.
type SSHManager struct {
sshHost string
sshUser string
sshPassword string
sshKeyFile string
sshConfigFile string
sshConfigHost string
reconnectSec time.Duration
// mu is held by the methods below whenever they read or replace client,
// jumpClients, lastError or clientHealthy.
mu sync.RWMutex
client *ssh.Client
jumpClients []*ssh.Client // connections to the jump hosts in the chain
lastError error
clientHealthy bool
}
// HTTPProxy implements http.Handler and acts as an HTTP proxy whose outbound
// connections are opened through the SSHManager.
type HTTPProxy struct {
sshManager *SSHManager
}
// SOCKS5Server is a SOCKS5 proxy server. When username or password is
// non-empty, clients must authenticate with the username/password method.
type SOCKS5Server struct {
sshManager *SSHManager
username string
password string
}
// SOCKS5 protocol constants.
const (
// Protocol version byte.
socks5Version = 0x05
// Authentication method identifiers used during the handshake.
socks5NoAuth = 0x00
socks5UserPass = 0x02
socks5NoAcceptable = 0xFF
// Request command; CONNECT is the only one handled by this server.
socks5Connect = 0x01
// Address type (ATYP) values.
socks5IPv4 = 0x01
socks5Domain = 0x03
socks5IPv6 = 0x04
// Reply (REP) codes sent back to the client.
socks5Success = 0x00
socks5GeneralFailure = 0x01
socks5ConnectionNotAllowed = 0x02
socks5NetworkUnreachable = 0x03
socks5HostUnreachable = 0x04
socks5ConnectionRefused = 0x05
socks5TTLExpired = 0x06
socks5CommandNotSupported = 0x07
socks5AddressNotSupported = 0x08
)
// ListenAndServe listens on the TCP address addr and serves SOCKS5 clients,
// handling every accepted connection in its own goroutine.
//
// It returns an error only when the listener cannot be created. Errors from
// Accept are logged and retried after a pause that doubles on each
// consecutive failure (5ms up to one second), so that a persistent failure
// such as file-descriptor exhaustion does not turn the loop into a busy spin
// that floods the log.
func (s *SOCKS5Server) ListenAndServe(addr string) error {
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	defer listener.Close()

	const (
		minAcceptDelay = 5 * time.Millisecond
		maxAcceptDelay = time.Second
	)
	var acceptDelay time.Duration
	for {
		conn, err := listener.Accept()
		if err != nil {
			if acceptDelay == 0 {
				acceptDelay = minAcceptDelay
			} else {
				acceptDelay *= 2
			}
			if acceptDelay > maxAcceptDelay {
				acceptDelay = maxAcceptDelay
			}
			log.Printf("SOCKS5 accept error: %v", err)
			time.Sleep(acceptDelay)
			continue
		}
		// A successful Accept resets the backoff.
		acceptDelay = 0
		go s.handleConnection(conn)
	}
}
// handleConnection serves a single SOCKS5 client from start to finish:
// method negotiation, optional username/password authentication, request
// parsing and finally relaying traffic to the requested target. The
// connection is always closed on return; failures are only logged.
func (s *SOCKS5Server) handleConnection(conn net.Conn) {
	defer conn.Close()

	peer, _, _ := net.SplitHostPort(conn.RemoteAddr().String())

	// Stage 1: negotiate the authentication method.
	if err := s.handshake(conn); err != nil {
		log.Printf("[SOCKS5] Client %s handshake failed: %v", peer, err)
		return
	}

	// Stage 2: authenticate, but only when credentials are configured.
	authRequired := !(s.username == "" && s.password == "")
	if authRequired {
		if err := s.authenticate(conn); err != nil {
			log.Printf("[SOCKS5] Client %s authentication failed: %v", peer, err)
			return
		}
	}

	// Stage 3: read the request and work out the target address.
	target, err := s.handleRequest(conn)
	if err != nil {
		log.Printf("[SOCKS5] Client %s request failed: %v", peer, err)
		return
	}
	log.Printf("[SOCKS5] Client %s -> Target %s", peer, target)

	// Stage 4: connect to the target and shuttle data in both directions.
	err = s.connectAndRelay(conn, target)
	if err != nil {
		log.Printf("[SOCKS5] Client %s relay error: %v", peer, err)
	}
}
// handshake performs the SOCKS5 method negotiation.
//
// The client greeting has the layout
//
//	+----+----------+----------+
//	|VER | NMETHODS | METHODS  |
//	+----+----------+----------+
//	| 1  |    1     | 1 to 255 |
//	+----+----------+----------+
//
// and the server answers with
//
//	+----+--------+
//	|VER | METHOD |
//	+----+--------+
//	| 1  |   1    |
//	+----+--------+
//
// Exactly 2+NMETHODS bytes are consumed from conn. Reading a larger chunk
// could swallow the start of the next message when a client pipelines its
// greeting and the following auth/request message in one packet.
//
// It returns an error when the greeting cannot be read, the version is not
// SOCKS5, the reply cannot be written, or the client does not offer the
// method this server requires (username/password when credentials are
// configured, "no authentication" otherwise).
func (s *SOCKS5Server) handshake(conn net.Conn) error {
	header := make([]byte, 2)
	if _, err := io.ReadFull(conn, header); err != nil {
		return fmt.Errorf("read handshake: %w", err)
	}
	if header[0] != socks5Version {
		return fmt.Errorf("unsupported SOCKS version: %d", header[0])
	}

	methods := make([]byte, int(header[1]))
	if _, err := io.ReadFull(conn, methods); err != nil {
		return fmt.Errorf("read methods: %w", err)
	}

	// Exactly one method is acceptable, depending on whether credentials
	// were configured for this server.
	wanted := byte(socks5NoAuth)
	if s.username != "" || s.password != "" {
		wanted = socks5UserPass
	}
	selectedMethod := byte(socks5NoAcceptable)
	for _, method := range methods {
		if method == wanted {
			selectedMethod = wanted
			break
		}
	}

	// The selection is sent even when nothing is acceptable, so the client
	// learns why the connection is about to be closed.
	if _, err := conn.Write([]byte{socks5Version, selectedMethod}); err != nil {
		return fmt.Errorf("write method selection: %w", err)
	}
	if selectedMethod == socks5NoAcceptable {
		return fmt.Errorf("no acceptable authentication method")
	}
	return nil
}
// authenticate runs the SOCKS5 username/password sub-negotiation.
//
// The client message has the layout
//
//	+----+------+----------+------+----------+
//	|VER | ULEN |  UNAME   | PLEN |  PASSWD  |
//	+----+------+----------+------+----------+
//	| 1  |  1   | 1 to 255 |  1   | 1 to 255 |
//	+----+------+----------+------+----------+
//
// and the server answers with
//
//	+----+--------+
//	|VER | STATUS |
//	+----+--------+
//	| 1  |   1    |
//	+----+--------+
//
// Each field is read with its exact length so that no byte of the request
// that follows is consumed (and lost) when the client sends both messages
// back to back.
//
// It returns an error when the message cannot be read, the sub-negotiation
// version is not 0x01, the response cannot be written, or the credentials do
// not match the configured username and password.
func (s *SOCKS5Server) authenticate(conn net.Conn) error {
	header := make([]byte, 2)
	if _, err := io.ReadFull(conn, header); err != nil {
		return fmt.Errorf("read auth header: %w", err)
	}
	if header[0] != 0x01 { // username/password sub-negotiation version
		return fmt.Errorf("unsupported auth version: %d", header[0])
	}

	uname := make([]byte, int(header[1]))
	if _, err := io.ReadFull(conn, uname); err != nil {
		return fmt.Errorf("read username: %w", err)
	}

	var plen [1]byte
	if _, err := io.ReadFull(conn, plen[:]); err != nil {
		return fmt.Errorf("read password length: %w", err)
	}

	passwd := make([]byte, int(plen[0]))
	if _, err := io.ReadFull(conn, passwd); err != nil {
		return fmt.Errorf("read password: %w", err)
	}

	// NOTE(review): this is a plain, non-constant-time comparison.
	var status byte = 0x00 // success
	if string(uname) != s.username || string(passwd) != s.password {
		status = 0x01 // failure
	}

	// The status is reported to the client before a failure is returned.
	if _, err := conn.Write([]byte{0x01, status}); err != nil {
		return fmt.Errorf("write auth response: %w", err)
	}
	if status != 0x00 {
		return fmt.Errorf("invalid credentials")
	}
	return nil
}
// handleRequest reads one SOCKS5 request from conn and returns the requested
// target as a "host:port" string (IPv6 hosts are bracketed).
//
// The request has the layout
//
//	+----+-----+-------+------+----------+----------+
//	|VER | CMD |  RSV  | ATYP | DST.ADDR | DST.PORT |
//	+----+-----+-------+------+----------+----------+
//	| 1  |  1  | X'00' |  1   | Variable |    2     |
//	+----+-----+-------+------+----------+----------+
//
// Only the CONNECT command is supported. Every field is read with its exact
// length, so payload bytes a client sends right behind the request are left
// in the connection for the relay instead of being discarded.
//
// On any failure a SOCKS5 error reply is sent to the client (best effort)
// and a descriptive error is returned.
func (s *SOCKS5Server) handleRequest(conn net.Conn) (string, error) {
	// read fetches exactly size bytes; on failure it reports a general
	// failure to the client and returns an error naming the field.
	read := func(size int, what string) ([]byte, error) {
		chunk := make([]byte, size)
		if _, err := io.ReadFull(conn, chunk); err != nil {
			s.sendReply(conn, socks5GeneralFailure, nil)
			return nil, fmt.Errorf("read %s: %w", what, err)
		}
		return chunk, nil
	}

	header, err := read(4, "request header")
	if err != nil {
		return "", err
	}
	if header[0] != socks5Version {
		s.sendReply(conn, socks5GeneralFailure, nil)
		return "", fmt.Errorf("unsupported version: %d", header[0])
	}
	if cmd := header[1]; cmd != socks5Connect {
		s.sendReply(conn, socks5CommandNotSupported, nil)
		return "", fmt.Errorf("unsupported command: %d", cmd)
	}

	var host string
	switch atyp := header[3]; atyp {
	case socks5IPv4:
		raw, err := read(net.IPv4len, "IPv4")
		if err != nil {
			return "", err
		}
		host = net.IP(raw).String()
	case socks5IPv6:
		raw, err := read(net.IPv6len, "IPv6")
		if err != nil {
			return "", err
		}
		host = net.IP(raw).String()
	case socks5Domain:
		// Domain names are prefixed with a one-byte length.
		size, err := read(1, "domain length")
		if err != nil {
			return "", err
		}
		if size[0] == 0 {
			s.sendReply(conn, socks5GeneralFailure, nil)
			return "", fmt.Errorf("empty domain name")
		}
		name, err := read(int(size[0]), "domain")
		if err != nil {
			return "", err
		}
		host = string(name)
	default:
		s.sendReply(conn, socks5AddressNotSupported, nil)
		return "", fmt.Errorf("unsupported address type: %d", atyp)
	}

	rawPort, err := read(2, "port")
	if err != nil {
		return "", err
	}
	port := binary.BigEndian.Uint16(rawPort)

	// JoinHostPort brackets the host only when it contains a colon, so an
	// IPv4-mapped IPv6 address (printed in dotted form) stays dialable.
	return net.JoinHostPort(host, fmt.Sprintf("%d", port)), nil
}
// sendReply writes a SOCKS5 reply carrying the reply code rep to conn.
//
//	+----+-----+-------+------+----------+----------+
//	|VER | REP |  RSV  | ATYP | BND.ADDR | BND.PORT |
//	+----+-----+-------+------+----------+----------+
//	| 1  |  1  | X'00' |  1   | Variable |    2     |
//	+----+-----+-------+------+----------+----------+
//
// bindAddr, when non-nil, supplies BND.ADDR and BND.PORT. If it is nil, or
// its host part is not a literal IP address, the reply carries the IPv4 zero
// address, so the message is always well formed. An unparseable or
// out-of-range port is sent as 0.
//
// It returns the error from writing to conn, if any.
func (s *SOCKS5Server) sendReply(conn net.Conn, rep byte, bindAddr net.Addr) error {
	atyp := byte(socks5IPv4)
	addr := []byte{0, 0, 0, 0}
	port := 0

	if bindAddr != nil {
		if host, portStr, err := net.SplitHostPort(bindAddr.String()); err == nil {
			if ip := net.ParseIP(host); ip != nil {
				if ip4 := ip.To4(); ip4 != nil {
					addr = ip4
				} else {
					atyp = socks5IPv6
					addr = ip.To16()
				}
			}
			if _, err := fmt.Sscanf(portStr, "%d", &port); err != nil || port < 0 || port > 65535 {
				port = 0
			}
		}
	}

	reply := make([]byte, 0, 4+len(addr)+2)
	reply = append(reply, socks5Version, rep, 0x00, atyp)
	reply = append(reply, addr...)
	portBytes := make([]byte, 2)
	binary.BigEndian.PutUint16(portBytes, uint16(port))
	reply = append(reply, portBytes...)

	_, err := conn.Write(reply)
	return err
}
// connectAndRelay opens a connection to targetAddr through the shared SSH
// client, reports the outcome to the SOCKS5 client and then copies data in
// both directions until either side finishes.
//
// Obtaining the SSH client and dialing the target are attempted up to three
// times with a growing pause (1s, 2s) between attempts. The shared SSH
// connection is only marked unhealthy when the dial error is something other
// than *ssh.OpenChannelError. A rejected channel means the SSH server
// answered and merely could not or would not reach the target, so tearing
// down the connection would needlessly break every other tunnel using it.
//
// It returns an error when no SSH client is available, the target cannot be
// reached, or the success reply cannot be written. Errors during the relay
// itself are not reported.
func (s *SOCKS5Server) connectAndRelay(clientConn net.Conn, targetAddr string) error {
	maxRetries := 3
	var sshClient *ssh.Client
	var targetConn net.Conn
	var err error
	for retry := 0; retry < maxRetries; retry++ {
		sshClient, err = s.sshManager.getSSHClient()
		if err != nil {
			log.Printf("Failed to get SSH client for SOCKS5 (attempt %d/%d): %v", retry+1, maxRetries, err)
			if retry < maxRetries-1 {
				time.Sleep(time.Second * time.Duration(retry+1))
				continue
			}
			s.sendReply(clientConn, socks5GeneralFailure, nil)
			return fmt.Errorf("SSH unavailable: %w", err)
		}
		targetConn, err = sshClient.Dial("tcp", targetAddr)
		if err != nil {
			log.Printf("Failed to dial target %s via SSH (attempt %d/%d): %v", targetAddr, retry+1, maxRetries, err)
			// Only a transport-level failure says anything about the
			// health of the SSH connection itself.
			if _, rejected := err.(*ssh.OpenChannelError); !rejected {
				s.sshManager.markUnhealthy()
			}
			if retry < maxRetries-1 {
				time.Sleep(time.Second * time.Duration(retry+1))
				continue
			}
			s.sendReply(clientConn, socks5HostUnreachable, nil)
			return fmt.Errorf("dial target: %w", err)
		}
		break
	}
	defer targetConn.Close()

	// Tell the client the connection is established.
	if err := s.sendReply(clientConn, socks5Success, targetConn.LocalAddr()); err != nil {
		return fmt.Errorf("send reply: %w", err)
	}

	// Relay in both directions. When one direction ends, the connection it
	// was writing to is closed, which unblocks the copy running the other
	// way.
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		io.Copy(targetConn, clientConn)
		targetConn.Close()
	}()
	go func() {
		defer wg.Done()
		io.Copy(clientConn, targetConn)
		clientConn.Close()
	}()
	wg.Wait()
	return nil
}
// parseSSHConfig reads the ssh_config style file at configPath (a leading
// "~/" is expanded) and returns the settings of the Host entry whose alias
// equals targetHost.
//
// Supported keywords (case-insensitive): Host, HostName, Port, User,
// IdentityFile, ProxyJump (comma separated) and the non-standard Password.
// Comment lines, blank lines, unknown keywords and settings that appear
// before the first Host line are ignored. Patterns and wildcards are not
// interpreted: aliases are matched literally.
//
// A Host line may list several aliases separated by whitespace; the entry is
// reachable under each of them. When the entry has no HostName, the alias
// itself is used as the host name. Port defaults to "22".
//
// It returns an error when the file cannot be opened or read, or when no
// entry matches targetHost.
func parseSSHConfig(configPath, targetHost string) (*SSHConfig, error) {
	file, err := os.Open(parseShellPath(configPath))
	if err != nil {
		return nil, fmt.Errorf("open SSH config: %w", err)
	}
	defer file.Close()

	configs := make(map[string]*SSHConfig)
	var current *SSHConfig

	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		fields := strings.Fields(scanner.Text())
		// Skip blank lines, comments and keywords without a value.
		if len(fields) < 2 || strings.HasPrefix(fields[0], "#") {
			continue
		}
		key := strings.ToLower(fields[0])
		value := strings.Join(fields[1:], " ")

		if key == "host" {
			current = &SSHConfig{
				Host: value,
				Port: "22", // default port
			}
			// Register the entry under every alias on the line. A later
			// entry with the same alias replaces an earlier one.
			for _, alias := range fields[1:] {
				configs[alias] = current
			}
			continue
		}
		if current == nil {
			continue
		}

		switch key {
		case "hostname":
			current.HostName = value
		case "port":
			current.Port = value
		case "user":
			current.User = value
		case "identityfile":
			// Expand a leading "~/" to the user's home directory.
			current.IdentityFile = parseShellPath(value)
		case "password":
			current.Password = value
		case "proxyjump":
			// Several jump hosts may be given, separated by commas.
			jumps := strings.Split(value, ",")
			for i := range jumps {
				jumps[i] = strings.TrimSpace(jumps[i])
			}
			current.ProxyJump = jumps
		}
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("read SSH config: %w", err)
	}

	config, ok := configs[targetHost]
	if !ok {
		return nil, fmt.Errorf("host %s not found in SSH config", targetHost)
	}
	if config.HostName == "" {
		// Without HostName the alias itself names the machine.
		config.HostName = targetHost
	}
	return config, nil
}
// getSSHClient returns a usable SSH client, reconnecting when necessary.
//
// If a client is cached and flagged healthy, it is probed with
// testConnection and returned when the probe succeeds. Otherwise reconnect
// is called to establish a new connection, or to pick up one that another
// goroutine established in the meantime.
//
// A client that fails the probe is torn down (together with its jump-host
// connections) only if it is still the manager's current client. This keeps
// a slow prober from closing a fresh connection that another goroutine
// installed while the probe was running.
func (m *SSHManager) getSSHClient() (*ssh.Client, error) {
	m.mu.RLock()
	client, healthy := m.client, m.clientHealthy
	m.mu.RUnlock()

	if client != nil && healthy {
		// Check that the connection really works before handing it out.
		if m.testConnection(client) {
			return client, nil
		}
		log.Println("SSH connection test failed, marking as unhealthy")

		m.mu.Lock()
		if m.client == client {
			m.client.Close()
			m.client = nil
			for _, jc := range m.jumpClients {
				if jc != nil {
					jc.Close()
				}
			}
			m.jumpClients = nil
			m.clientHealthy = false
		}
		m.mu.Unlock()
	}

	// A new connection is needed.
	return m.reconnect()
}
// testConnection reports whether client is still usable.
//
// It first sends a keepalive global request and waits for the reply, then
// tries to open a session channel. If the server explicitly rejects the
// session (*ssh.OpenChannelError), the connection still counts as alive: the
// server answered, it just does not allow sessions for this account.
// Treating that as a dead connection would force a reconnect on every
// request. Any other error means the connection is unusable.
func (m *SSHManager) testConnection(client *ssh.Client) bool {
	// Probe 1: keepalive round trip.
	if _, _, err := client.SendRequest("keepalive@openssh.com", true, nil); err != nil {
		return false
	}

	// Probe 2: open (and immediately close) a session channel.
	session, err := client.NewSession()
	if err != nil {
		_, rejected := err.(*ssh.OpenChannelError)
		return rejected
	}
	session.Close()
	return true
}
// markUnhealthy closes the current SSH client and every jump-host
// connection, forgets them, and clears the healthy flag. The whole operation
// runs under the manager's write lock.
func (m *SSHManager) markUnhealthy() {
	m.mu.Lock()
	defer m.mu.Unlock()

	if current := m.client; current != nil {
		current.Close()
	}
	m.client = nil

	// Shut down the whole jump-host chain as well.
	for i := range m.jumpClients {
		if hop := m.jumpClients[i]; hop != nil {
			hop.Close()
		}
	}
	m.jumpClients, m.clientHealthy = nil, false
}
// reconnect establishes a new SSH connection and installs it as the
// manager's current client.
//
// It never returns a non-nil error: a failed dial is logged, stored in
// lastError, and retried after reconnectSec until a connection succeeds or
// another goroutine has installed a healthy client in the meantime.
//
// The write lock is held for the whole call, including the dial itself. It
// is released only while sleeping between attempts and re-acquired before
// the loop continues, so the deferred Unlock always has a held lock to
// release.
func (m *SSHManager) reconnect() (*ssh.Client, error) {
m.mu.Lock()
defer m.mu.Unlock()
// Check again under the write lock: if a healthy connection already
// exists, return it directly.
if m.client != nil && m.clientHealthy {
return m.client, nil
}
// Clean up the old connections.
if m.client != nil {
m.client.Close()
m.client = nil
}
for _, jc := range m.jumpClients {
if jc != nil {
jc.Close()
}
}
m.jumpClients = nil
m.clientHealthy = false
for {
var client *ssh.Client
var err error
if m.sshConfigHost != "" {
log.Printf("Connecting via SSH config host: %s", m.sshConfigHost)
client, err = m.dialSSHWithConfig()
} else {
log.Printf("Connecting to SSH server: %s", m.sshHost)
client, err = m.dialSSH()
}
if err == nil {
m.client = client
m.clientHealthy = true
m.lastError = nil
log.Println("SSH connection established")
go m.keepAlive(client)
return client, nil
}
m.lastError = err
log.Printf("Failed to connect SSH: %v, retrying in %v...", err, m.reconnectSec)
// Release the lock while waiting to retry so other requests are not
// blocked.
m.mu.Unlock()
time.Sleep(m.reconnectSec)
m.mu.Lock()
// After re-acquiring the lock, check whether another goroutine has
// already established a connection.
if m.client != nil && m.clientHealthy {
return m.client, nil
}
}
}
// dialSSHWithConfig connects to the host named by sshConfigHost using the
// settings from the SSH config file, going through the entry's ProxyJump
// hosts when it has any.
//
// The target address is built with net.JoinHostPort, so IPv6 literals (with
// or without brackets in the config) produce a dialable address. An entry
// without HostName falls back to the host alias. The user defaults to the
// manager's sshUser when the entry does not set one.
//
// It returns an error when the config cannot be loaded, no authentication
// method is available, or the connection fails.
func (m *SSHManager) dialSSHWithConfig() (*ssh.Client, error) {
	config, err := parseSSHConfig(m.sshConfigFile, m.sshConfigHost)
	if err != nil {
		return nil, err
	}

	// Load every entry so the jump hosts can be looked up by alias.
	allConfigs, err := m.loadAllSSHConfigs()
	if err != nil {
		return nil, err
	}

	// Authentication methods for the target server.
	targetAuth, err := m.getAuthMethods(config)
	if err != nil {
		return nil, err
	}

	hostName := config.HostName
	if hostName == "" {
		hostName = m.sshConfigHost
	}
	// Trim optional brackets first; JoinHostPort adds them back for IPv6.
	targetAddr := net.JoinHostPort(strings.Trim(hostName, "[]"), config.Port)

	targetUser := config.User
	if targetUser == "" {
		targetUser = m.sshUser
	}

	// Without jump hosts, connect directly.
	if len(config.ProxyJump) == 0 {
		log.Printf("Direct connection to %s@%s", targetUser, targetAddr)
		return m.dialSSHDirect(targetAddr, targetUser, targetAuth)
	}

	// With jump hosts, build the tunnel hop by hop (ssh -W style).
	log.Printf("Connecting through %d jump server(s) using SSH -W mode", len(config.ProxyJump))
	return m.dialSSHWithJumps(config, allConfigs, targetUser, targetAddr, targetAuth)
}
// loadAllSSHConfigs parses the manager's SSH config file and returns every
// Host entry keyed by alias.
//
// Supported keywords (case-insensitive): Host, HostName, Port, User,
// IdentityFile, ProxyJump (comma separated) and the non-standard Password.
// Comment lines, blank lines, unknown keywords and settings before the first
// Host line are ignored; aliases are matched literally.
//
// A Host line listing several aliases yields one map entry per alias. An
// entry without HostName uses its alias as the host name. Port defaults to
// "22". When an alias occurs more than once, the later entry wins.
//
// It returns an error when the file cannot be opened or read.
func (m *SSHManager) loadAllSSHConfigs() (map[string]*SSHConfig, error) {
	file, err := os.Open(parseShellPath(m.sshConfigFile))
	if err != nil {
		return nil, fmt.Errorf("open SSH config: %w", err)
	}
	defer file.Close()

	// block is one Host section together with the aliases it was declared
	// under.
	type block struct {
		aliases []string
		config  *SSHConfig
	}
	var blocks []block
	var current *SSHConfig

	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		fields := strings.Fields(scanner.Text())
		if len(fields) < 2 || strings.HasPrefix(fields[0], "#") {
			continue
		}
		key := strings.ToLower(fields[0])
		value := strings.Join(fields[1:], " ")

		if key == "host" {
			current = &SSHConfig{Host: value, Port: "22"}
			blocks = append(blocks, block{aliases: fields[1:], config: current})
			continue
		}
		if current == nil {
			continue
		}

		switch key {
		case "hostname":
			current.HostName = value
		case "port":
			current.Port = value
		case "user":
			current.User = value
		case "identityfile":
			// Expand a leading "~/" to the user's home directory.
			current.IdentityFile = parseShellPath(value)
		case "password":
			current.Password = value
		case "proxyjump":
			jumps := strings.Split(value, ",")
			for i := range jumps {
				jumps[i] = strings.TrimSpace(jumps[i])
			}
			current.ProxyJump = jumps
		}
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("read SSH config: %w", err)
	}

	allConfigs := make(map[string]*SSHConfig)
	for _, b := range blocks {
		for _, alias := range b.aliases {
			// Each alias gets its own copy so the HostName default can
			// differ per alias.
			resolved := *b.config
			if resolved.HostName == "" {
				resolved.HostName = alias
			}
			allConfigs[alias] = &resolved
		}
	}
	return allConfigs, nil
}
// dialSSHWithJumps connects to the target through the chain of jump hosts
// listed in targetConfig.ProxyJump, in order. The first jump host is dialed
// directly; every later hop, and finally the target, is reached through a
// direct-tcpip channel opened on the previous hop (what "ssh -W" does).
//
// Each jump name must be an alias present in allConfigs. A hop's address is
// built with net.JoinHostPort from its HostName (falling back to the alias
// when empty) and Port, so IPv6 literals work. The user falls back to the
// manager's sshUser.
//
// Every jump-host client is appended to m.jumpClients so it can be closed
// later. If a hop fails after the first jump host is connected, all
// jump-host connections made so far are closed before returning.
//
// It returns the client connected to the target, or an error naming the hop
// that failed.
func (m *SSHManager) dialSSHWithJumps(targetConfig *SSHConfig, allConfigs map[string]*SSHConfig, targetUser, targetAddr string, targetAuth []ssh.AuthMethod) (*ssh.Client, error) {
	if len(targetConfig.ProxyJump) == 0 {
		return nil, fmt.Errorf("no jump hosts configured for %s", targetConfig.Host)
	}

	// hopAddr builds the dial address of a jump host.
	hopAddr := func(alias string, c *SSHConfig) string {
		host := c.HostName
		if host == "" {
			host = alias
		}
		// Trim optional brackets; JoinHostPort adds them back for IPv6.
		return net.JoinHostPort(strings.Trim(host, "[]"), c.Port)
	}
	// hopUser picks the login user of a jump host.
	hopUser := func(c *SSHConfig) string {
		if c.User != "" {
			return c.User
		}
		return m.sshUser
	}

	// Connect to the first jump host directly.
	firstJumpName := targetConfig.ProxyJump[0]
	firstJumpConfig, ok := allConfigs[firstJumpName]
	if !ok {
		return nil, fmt.Errorf("jump host %s not found in config", firstJumpName)
	}
	firstJumpAuth, err := m.getAuthMethods(firstJumpConfig)
	if err != nil {
		return nil, fmt.Errorf("get auth for first jump %s: %w", firstJumpName, err)
	}
	firstJumpAddr := hopAddr(firstJumpName, firstJumpConfig)
	firstJumpUser := hopUser(firstJumpConfig)
	log.Printf("Connecting to first jump: %s@%s", firstJumpUser, firstJumpAddr)
	currentClient, err := m.dialSSHDirect(firstJumpAddr, firstJumpUser, firstJumpAuth)
	if err != nil {
		return nil, fmt.Errorf("connect to first jump %s: %w", firstJumpName, err)
	}
	// Remember the jump-host connection for later cleanup.
	m.jumpClients = append(m.jumpClients, currentClient)

	// Remaining jump hosts are reached one after another through the
	// previous hop.
	for i := 1; i < len(targetConfig.ProxyJump); i++ {
		jumpName := targetConfig.ProxyJump[i]
		jumpConfig, ok := allConfigs[jumpName]
		if !ok {
			m.closeJumpClients()
			return nil, fmt.Errorf("jump host %s not found in config", jumpName)
		}
		jumpAuth, err := m.getAuthMethods(jumpConfig)
		if err != nil {
			m.closeJumpClients()
			return nil, fmt.Errorf("get auth for jump %s: %w", jumpName, err)
		}
		jumpAddr := hopAddr(jumpName, jumpConfig)
		jumpUser := hopUser(jumpConfig)
		log.Printf("Connecting to jump %d via SSH -W: %s@%s", i+1, jumpUser, jumpAddr)
		nextClient, err := m.dialSSHThroughDirectTCPIP(currentClient, jumpAddr, jumpUser, jumpAuth)
		if err != nil {
			m.closeJumpClients()
			return nil, fmt.Errorf("connect to jump %s via SSH -W: %w", jumpName, err)
		}
		m.jumpClients = append(m.jumpClients, nextClient)
		currentClient = nextClient
	}

	// Finally reach the target through the last jump host.
	log.Printf("Connecting to target via SSH -W: %s@%s", targetUser, targetAddr)
	finalClient, err := m.dialSSHThroughDirectTCPIP(currentClient, targetAddr, targetUser, targetAuth)
	if err != nil {
		m.closeJumpClients()
		return nil, fmt.Errorf("connect to target via SSH -W: %w", err)
	}
	return finalClient, nil
}
// closeJumpClients closes every non-nil jump-host connection and then drops
// the list. It does not take the manager's lock itself.
func (m *SSHManager) closeJumpClients() {
	for i := 0; i < len(m.jumpClients); i++ {
		if hop := m.jumpClients[i]; hop != nil {
			hop.Close()
		}
	}
	m.jumpClients = nil
}
// getAuthMethods assembles the SSH authentication methods for one host
// entry.
//
// Password: the entry's Password wins; otherwise the manager's command-line
// password is used, if set. Private key: the entry's IdentityFile wins;
// otherwise the manager's key file is used, if set. The key is decrypted
// with the --ssh-key-passphrase flag value when that flag is non-empty.
//
// It returns an error when the key file cannot be read or parsed, or when
// neither a password nor a key is available.
func (m *SSHManager) getAuthMethods(config *SSHConfig) ([]ssh.AuthMethod, error) {
	var methods []ssh.AuthMethod

	// Password from the config entry first, then from the command line.
	switch {
	case config.Password != "":
		methods = append(methods, ssh.Password(config.Password))
	case m.sshPassword != "":
		methods = append(methods, ssh.Password(m.sshPassword))
	}

	// Key from the config entry first, then from the command line.
	keyPath := config.IdentityFile
	if keyPath == "" {
		keyPath = m.sshKeyFile
	}

	if keyPath != "" {
		raw, err := os.ReadFile(keyPath)
		if err != nil {
			return nil, fmt.Errorf("read private key file %s: %w", keyPath, err)
		}

		var signer ssh.Signer
		if passphrase := *sshKeyPassphrase; passphrase == "" {
			signer, err = ssh.ParsePrivateKey(raw)
			if err != nil {
				// The key may be encrypted and need a passphrase.
				return nil, fmt.Errorf("parse private key %s (may need passphrase): %w", keyPath, err)
			}
		} else {
			signer, err = ssh.ParsePrivateKeyWithPassphrase(raw, []byte(passphrase))
			if err != nil {
				return nil, fmt.Errorf("parse encrypted private key %s: %w", keyPath, err)
			}
		}
		methods = append(methods, ssh.PublicKeys(signer))
	}

	if len(methods) == 0 {
		return nil, fmt.Errorf("no authentication method available")
	}
	return methods, nil
}
// dialSSHDirect opens a TCP connection to addr and performs the SSH
// handshake as user with the given authentication methods, using a
// 10-second dial timeout.
//
// NOTE(review): the server's host key is accepted without verification
// (ssh.InsecureIgnoreHostKey).
func (m *SSHManager) dialSSHDirect(addr, user string, auth []ssh.AuthMethod) (*ssh.Client, error) {
	var clientConfig ssh.ClientConfig
	clientConfig.User = user
	clientConfig.Auth = auth
	clientConfig.HostKeyCallback = ssh.InsecureIgnoreHostKey()
	clientConfig.Timeout = 10 * time.Second

	return ssh.Dial("tcp", addr, &clientConfig)
}
// dialSSHThroughDirectTCPIP connects to the next hop at targetAddr through
// an existing SSH client. It opens a direct-tcpip channel on jumpClient (the
// mechanism behind "ssh -W targetHost:targetPort jumpHost") and runs a new
// SSH handshake over that channel.
//
// ssh.ClientConfig.Timeout only bounds the TCP dial done by ssh.Dial; it has
// no effect on ssh.NewClientConn, and the tunnelled channel does not support
// deadlines. The handshake therefore runs in a goroutine and is abandoned
// after 10 seconds by closing the channel. Without this, an unresponsive
// next hop would block the caller indefinitely.
//
// NOTE(review): the next hop's host key is accepted without verification
// (ssh.InsecureIgnoreHostKey).
//
// It returns the new client, or an error when the channel cannot be opened
// or the handshake fails or times out.
func (m *SSHManager) dialSSHThroughDirectTCPIP(jumpClient *ssh.Client, targetAddr, targetUser string, targetAuth []ssh.AuthMethod) (*ssh.Client, error) {
	log.Printf("Opening direct-tcpip channel to %s", targetAddr)
	conn, err := jumpClient.Dial("tcp", targetAddr)
	if err != nil {
		return nil, fmt.Errorf("dial through jump host via direct-tcpip: %w", err)
	}

	const handshakeTimeout = 10 * time.Second
	config := &ssh.ClientConfig{
		User:            targetUser,
		Auth:            targetAuth,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         handshakeTimeout,
	}

	type handshakeResult struct {
		client *ssh.Client
		err    error
	}
	// Buffered so the goroutine can always deliver its result and exit,
	// even after the caller has given up.
	done := make(chan handshakeResult, 1)
	go func() {
		ncc, chans, reqs, err := ssh.NewClientConn(conn, targetAddr, config)
		if err != nil {
			done <- handshakeResult{err: err}
			return
		}
		done <- handshakeResult{client: ssh.NewClient(ncc, chans, reqs)}
	}()

	timer := time.NewTimer(handshakeTimeout)
	defer timer.Stop()

	select {
	case res := <-done:
		if res.err != nil {
			conn.Close()
			return nil, fmt.Errorf("SSH handshake through tunnel: %w", res.err)
		}
		return res.client, nil
	case <-timer.C:
		// Closing the channel makes the pending handshake fail. If it
		// happened to complete at the same moment, close that client too.
		conn.Close()
		go func() {
			if res := <-done; res.client != nil {
				res.client.Close()
			}
		}()
		return nil, fmt.Errorf("SSH handshake through tunnel: timed out after %v", handshakeTimeout)
	}
}
// parseShellPath expands a leading "~" in path the way a shell would for the
// current user: "~" becomes the home directory and "~/rest" becomes
// "<home>/rest". Any other path, including "~user/..." forms and the empty
// string, is returned unchanged.
//
// If the home directory cannot be determined, the failure is logged and path
// is returned unchanged.
func parseShellPath(path string) string {
	if path != "~" && !strings.HasPrefix(path, "~/") {
		return path
	}
	homeDir, err := os.UserHomeDir()
	if err != nil {
		log.Printf("Failed to get user home directory: %v", err)
		return path
	}
	// path[1:] is "" or starts with "/"; Join cleans both cases.
	return filepath.Join(homeDir, path[1:])
}
// dialSSH connects straight to the server given by --ssh-host using the
// password and/or private key supplied on the command line. The key is
// decrypted with --ssh-key-passphrase when that flag is non-empty.
//
// If sshHost carries no port, ":22" is appended and the result is stored
// back in m.sshHost. Whether a port is present is decided with
// net.SplitHostPort, so a bare IPv6 literal such as "::1" or "[::1]" is
// correctly recognised as having no port, which a plain search for ":" gets
// wrong.
//
// NOTE(review): the server's host key is accepted without verification
// (ssh.InsecureIgnoreHostKey).
//
// It returns an error when the key file cannot be read or parsed, or when
// the connection or authentication fails.
func (m *SSHManager) dialSSH() (*ssh.Client, error) {
	var auth []ssh.AuthMethod
	if m.sshPassword != "" {
		auth = append(auth, ssh.Password(m.sshPassword))
	}

	if m.sshKeyFile != "" {
		keyBytes, err := os.ReadFile(m.sshKeyFile)
		if err != nil {
			return nil, fmt.Errorf("read private key file: %w", err)
		}
		var signer ssh.Signer
		if *sshKeyPassphrase != "" {
			signer, err = ssh.ParsePrivateKeyWithPassphrase(keyBytes, []byte(*sshKeyPassphrase))
			if err != nil {
				return nil, fmt.Errorf("failed to parse encrypted private key with given passphrase: %w", err)
			}
		} else {
			signer, err = ssh.ParsePrivateKey(keyBytes)
			if err != nil {
				return nil, fmt.Errorf("failed to parse private key (it may be encrypted; try --ssh-key-passphrase): %w", err)
			}
		}
		auth = append(auth, ssh.PublicKeys(signer))
	}

	// Append the default port when the address has none. Brackets are
	// trimmed first because JoinHostPort adds them back for IPv6.
	if _, _, err := net.SplitHostPort(m.sshHost); err != nil {
		m.sshHost = net.JoinHostPort(strings.Trim(m.sshHost, "[]"), "22")
	}

	config := &ssh.ClientConfig{
		User:            m.sshUser,
		Auth:            auth,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         10 * time.Second,
	}
	return ssh.Dial("tcp", m.sshHost, config)
}
// keepAlive sends a keepalive request over client every 30 seconds and
// returns after the first request that fails. On failure, if client is still
// the manager's current client, it is closed and cleared together with all
// jump-host connections and the healthy flag is reset; otherwise the
// manager's state is left alone.
func (m *SSHManager) keepAlive(client *ssh.Client) {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for range ticker.C {
		_, _, err := client.SendRequest("keepalive@openssh.com", true, nil)
		if err == nil {
			continue
		}
		log.Printf("SSH keepalive failed: %v", err)

		m.mu.Lock()
		if m.client == client {
			m.clientHealthy = false
			m.client.Close()
			m.client = nil
			// Close every jump-host connection as well.
			for i := range m.jumpClients {
				if hop := m.jumpClients[i]; hop != nil {
					hop.Close()
				}
			}
			m.jumpClients = nil
		}
		m.mu.Unlock()
		return
	}
}
// checkAuth reports whether request r may use the proxy.
//
// When neither --http-proxy-user nor --http-proxy-pass is set, every request
// is allowed. Otherwise the Proxy-Authorization header must carry Basic
// credentials ("Basic " + base64("user:pass")) that match both flags. The
// scheme name is compared case-insensitively, so "basic" and "BASIC" are
// accepted as well.
func (p *HTTPProxy) checkAuth(r *http.Request) bool {
	if *httpProxyUser == "" && *httpProxyPass == "" {
		return true
	}

	const scheme = "Basic "
	auth := r.Header.Get("Proxy-Authorization")
	if len(auth) < len(scheme) || !strings.EqualFold(auth[:len(scheme)], scheme) {
		return false
	}

	payload, err := base64.StdEncoding.DecodeString(auth[len(scheme):])
	if err != nil {
		return false
	}

	// Split at the first colon only: the password may itself contain colons.
	parts := strings.SplitN(string(payload), ":", 2)
	if len(parts) != 2 {
		return false
	}
	return parts[0] == *httpProxyUser && parts[1] == *httpProxyPass
}
// ServeHTTP is the proxy's entry point. Requests that fail checkAuth are
// answered with 407 and a Basic challenge. Everything else is logged and
// dispatched: CONNECT requests go to handleConnect, all other methods to
// handleHTTP.
func (p *HTTPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if authorized := p.checkAuth(r); !authorized {
		w.Header().Set("Proxy-Authenticate", "Basic realm=\"Proxy\"")
		http.Error(w, "Proxy authentication required", http.StatusProxyAuthRequired)
		return
	}

	peer, _, _ := net.SplitHostPort(r.RemoteAddr)
	log.Printf("[HTTP] Client %s -> Target %s %s %s", peer, r.URL.Host, r.Method, r.URL.Path)

	switch r.Method {
	case http.MethodConnect:
		p.handleConnect(w, r)
	default:
		p.handleHTTP(w, r)
	}
}
// handleConnect serves an HTTP CONNECT request. It hijacks the client
// connection, opens a connection to r.Host through the shared SSH client,
// answers "200 Connection Established" and then copies bytes in both
// directions until either side finishes.
//
// Obtaining the SSH client and dialing the target are attempted up to three
// times with a growing pause (1s, 2s). On final failure a bare 503 (no SSH
// client) or 502 (target unreachable) status line is written to the client.
//
// The shared SSH connection is only marked unhealthy when the dial error is
// something other than *ssh.OpenChannelError. A rejected channel means the
// SSH server answered and merely could not or would not reach the target, so
// tearing down the connection would needlessly break every other tunnel
// using it.
func (p *HTTPProxy) handleConnect(w http.ResponseWriter, r *http.Request) {
	dest := r.Host
	hijacker, ok := w.(http.Hijacker)
	if !ok {
		http.Error(w, "Hijack not supported", http.StatusInternalServerError)
		return
	}
	clientConn, _, err := hijacker.Hijack()
	if err != nil {
		log.Printf("Hijack error: %v", err)
		return
	}
	defer clientConn.Close()

	var sshClient *ssh.Client
	var targetConn net.Conn
	maxRetries := 3
	for retry := 0; retry < maxRetries; retry++ {
		sshClient, err = p.sshManager.getSSHClient()
		if err != nil {
			log.Printf("Failed to get SSH client for CONNECT (attempt %d/%d): %v", retry+1, maxRetries, err)
			if retry < maxRetries-1 {
				time.Sleep(time.Second * time.Duration(retry+1))
				continue
			}
			clientConn.Write([]byte("HTTP/1.1 503 Service Unavailable\r\n\r\n"))
			return
		}
		targetConn, err = sshClient.Dial("tcp", dest)
		if err != nil {
			log.Printf("Failed to dial target %s via SSH (attempt %d/%d): %v", dest, retry+1, maxRetries, err)
			// Only a transport-level failure says anything about the
			// health of the SSH connection itself.
			if _, rejected := err.(*ssh.OpenChannelError); !rejected {
				p.sshManager.markUnhealthy()
			}
			if retry < maxRetries-1 {
				time.Sleep(time.Second * time.Duration(retry+1))
				continue
			}
			clientConn.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
			return
		}
		break
	}
	defer targetConn.Close()

	clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n"))

	// Relay in both directions. When one direction ends, the connection it
	// was writing to is closed, which unblocks the copy running the other
	// way.
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		io.Copy(targetConn, clientConn)
		targetConn.Close()
	}()
	go func() {
		defer wg.Done()
		io.Copy(clientConn, targetConn)
		clientConn.Close()
	}()
	wg.Wait()
}
// handleHTTP proxies a plain (non-CONNECT) HTTP request. It opens a
// connection to the URL's host through the shared SSH client, writes the
// request, and copies the response status, headers and body back to the
// client.
//
// The request URL must be absolute (start with "http"). When the URL has no
// explicit port, port 80 is used, because ssh.Client.Dial requires a
// "host:port" address. The hop-by-hop headers Proxy-Authorization and
// Proxy-Connection are removed before forwarding, so the proxy credentials
// are never sent to the origin server.
//
// Obtaining the SSH client and dialing the target are attempted up to three
// times with a growing pause (1s, 2s). The shared SSH connection is only
// marked unhealthy when the dial error is something other than
// *ssh.OpenChannelError, because a rejected channel means the SSH connection
// itself still works.
//
// Failures are reported as 400 (relative URL), 503 (no SSH client) or 502
// (dial, write or read error).
func (p *HTTPProxy) handleHTTP(w http.ResponseWriter, r *http.Request) {
	if !strings.HasPrefix(r.URL.String(), "http") {
		http.Error(w, "URL must be absolute", http.StatusBadRequest)
		return
	}

	targetAddr := r.URL.Host
	if r.URL.Port() == "" {
		targetAddr = net.JoinHostPort(r.URL.Hostname(), "80")
	}

	maxRetries := 3
	var sshClient *ssh.Client
	var targetConn net.Conn
	var err error
	for retry := 0; retry < maxRetries; retry++ {
		sshClient, err = p.sshManager.getSSHClient()
		if err != nil {
			log.Printf("Failed to get SSH client (attempt %d/%d): %v", retry+1, maxRetries, err)
			if retry < maxRetries-1 {
				time.Sleep(time.Second * time.Duration(retry+1))
				continue
			}
			http.Error(w, "SSH unavailable", http.StatusServiceUnavailable)
			return
		}
		targetConn, err = sshClient.Dial("tcp", targetAddr)
		if err != nil {
			log.Printf("Failed to dial %s (attempt %d/%d): %v", targetAddr, retry+1, maxRetries, err)
			// Only a transport-level failure says anything about the
			// health of the SSH connection itself.
			if _, rejected := err.(*ssh.OpenChannelError); !rejected {
				p.sshManager.markUnhealthy()
			}
			if retry < maxRetries-1 {
				time.Sleep(time.Second * time.Duration(retry+1))
				continue
			}
			http.Error(w, "Gateway error", http.StatusBadGateway)
			return
		}
		break
	}
	defer targetConn.Close()

	// These headers are addressed to this proxy, not to the origin server.
	r.Header.Del("Proxy-Authorization")
	r.Header.Del("Proxy-Connection")

	if err = r.Write(targetConn); err != nil {
		log.Printf("Failed to write request: %v", err)
		http.Error(w, "Write error", http.StatusBadGateway)
		return
	}

	resp, err := http.ReadResponse(bufio.NewReader(targetConn), r)
	if err != nil {
		log.Printf("Failed to read response: %v", err)
		http.Error(w, "Read error", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	for key, values := range resp.Header {
		for _, value := range values {
			w.Header().Add(key, value)
		}
	}
	w.WriteHeader(resp.StatusCode)
	io.Copy(w, resp.Body)
}