go get golang.org/x/crypto/ssh
日志部分每个人的需求不一样这里不赘述
下列代码会出现以下方法
func LogInfo(msg string) //记录信息
func LogError(msg string) //记录错误
func Logf(format string, args ...any) //格式化日志
type Result struct {
IP string
ConnectSuccess bool
CommandSuccess bool
Output string
Error error
}
这里只支持使用账号密码连接,不支持sshkey
func Connect(ip, user, password string, port int, command string) Result {
addr := fmt.Sprintf("%s:%d", ip, port)
config := &ssh.ClientConfig{
User: user,
Auth: []ssh.AuthMethod{
ssh.Password(password),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second, //5秒连接超时
}
client, err := ssh.Dial("tcp", addr, config)
if err != nil {
return Result{IP: ip, ConnectSuccess: false, Error: err}
}
defer client.Close()
//不需要执行命令时只返回IP和是否连接成功,需要执行命令时则额外返回命令的结果和执行命令成功与否
if command == "" {
return Result{IP: ip, ConnectSuccess: true}
}
session, err := client.NewSession()
if err != nil {
return Result{IP: ip, ConnectSuccess: false, Error: err}
}
defer session.Close()
output, err := session.CombinedOutput(command)
if err != nil {
return Result{IP: ip, ConnectSuccess: true, CommandSuccess: false, Error: err, Output: string(output)}
}
return Result{
IP: ip,
ConnectSuccess: true,
CommandSuccess: true,
Output: string(output),
}
}
由于是批量连接主机,ip和账号密码各不相同,所以需要读取文件,这里以对同文件夹下的hosts.ini文件进行读取,文件示例:
172.16.0.100 root 123456 22
172.16.0.101 root 1132123 2222
172.16.0.101 root 112213512
没有端口的默认是22 有端口的执行文件中的端口
type Host struct {
IP string
User string
Password string
Port int
}
func LoadHostFile(filePath string) ([]Host, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, fmt.Errorf("未找到hosts.ini文件,请检查")
}
defer file.Close()
var hosts []Host
lineNumber := 0
scanner := bufio.NewScanner(file)
for scanner.Scan() {
lineNumber++
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
if strings.HasPrefix(line, "#") {
continue
}
//读取错误时要告诉用户是第几行出错,方便修改
fields := strings.Fields(line)
if len(fields) < 3 {
return nil, fmt.Errorf("hosts文件 %d 行格式错误,请检查", line)
}
//端口不填默认22 提前转换为int
port, err := strconv.Atoi(fields[3])
if err != nil {
port = 22
}
host := Host{
IP: fields[0],
User: fields[1],
Password: fields[2],
Port: port,
}
hosts = append(hosts, host)
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("读取hosts文件失败 :%v", err)
}
if len(hosts) == 0 {
return nil, fmt.Errorf("hosts文件没有有效主机")
}
return hosts, nil
}
最终的执行是通过命令行执行,如
batchssh -cmd "ls -la" -w 20
所以main函数如下:
package main
import (
"flag"
"fmt"
"os"
"strings"
"sync"
"sync/atomic"
"time"
)
func main() {
cmdLine := strings.Join(os.Args, " ")
command := flag.String("cmd", "", "Command")
//最大连接数为20 超过了就等待
workers := flag.Int("w", 20, "Max Connections")
//解析命令行参数 如果不调用Parse()方法则不会解析并返回结果
flag.Parse()
hosts, _ := LoadHostFile("hosts.ini")
total := int32(len(hosts))
var completed int32
var wg sync.WaitGroup
var progressWg sync.WaitGroup
//用routine和管道来限制同时执行的数量
resultChannel := make(chan Result, len(hosts))
sem := make(chan struct{}, *workers)
progressWg.Go(func() {
for {
done := atomic.LoadInt32(&completed)
printProgress(done, total)
if done >= total {
fmt.Println()
return
}
time.Sleep(200 * time.Millisecond)
}
})
for _, host := range hosts {
wg.Add(1)
go func(h Host) {
defer wg.Done()
sem <- struct{}{}
result := Connect(h.IP, h.User, h.Password, h.Port, *command)
resultChannel <- result
<-sem
atomic.AddInt32(&completed, 1)
}(host)
}
wg.Wait()
close(resultChannel)
progressWg.Wait()
// 统计结果
var successCount int
var failIPs []string
for r := range resultChannel {
if r.ConnectSuccess {
successCount++
Logf("Connecting %s Succeed\n", r.IP)
LogInfo("-------- Command Output --------")
LogInfo(r.Output)
} else {
failIPs = append(failIPs, r.IP)
}
}
LogInfo("-------- Summary --------")
Logf("Total: %d\n", total)
Logf("Success: %d\n", successCount)
Logf("Failed: %d\n", total-int32(successCount))
if len(failIPs) > 0 {
LogInfo("Failed IPs:")
for _, ip := range failIPs {
LogInfo(ip)
}
}
LogInfo("")
LogInfo("")
LogInfo("")
}
// 打印一个进度条
func printProgress(done, total int32) {
percent := float64(done) / float64(total)
barWidth := 40
filled := int(percent * float64(barWidth))
var bar strings.Builder
for i := range barWidth {
if i < filled {
bar.WriteString("=")
} else {
bar.WriteString(" ")
}
}
fmt.Printf("\r[%s] %d/%d", bar.String(), done, total)
}