core/ssh/ssh.go

167 lines
3.3 KiB
Go

package ssh
import (
"bytes"
"fmt"
"io"
"log"
"net"
"os"
"os/user"
"gitea.meta-tech.academy/go/core/echo"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"golang.org/x/crypto/ssh/knownhosts"
)
// SshAuth holds the parameters needed to reach a remote SSH host. The yaml
// tags allow it to be populated from a YAML configuration file.
//
// Type is used as the network argument when dialing (for example "tcp"),
// while Host and Port together form the dial address.
type SshAuth struct {
	Host string `yaml:"host"` // remote host name or IP address
	Type string `yaml:"type"` // dial network, e.g. "tcp"
	Port int    `yaml:"port"` // remote SSH port
	User string `yaml:"user"` // login user name
}
// Ssh bundles an SSH connection with the session most recently opened on it
// and an optional extra writer that mirrors SFTP download data (see AddBar).
type Ssh struct {
	Config  *SshAuth     // connection parameters consumed by InitDial
	Client  *ssh.Client  // established SSH connection
	Session *ssh.Session // session last created by initSession
	bar     *io.Writer   // extra download sink registered via AddBar
	withBar bool         // whether bar should receive download data
}
// NewSsh creates an Ssh for config and immediately connects to the remote
// host by calling InitDial. The extra download writer starts disabled.
// Connection failures are fatal (InitDial exits via log.Fatal).
func NewSsh(config *SshAuth) *Ssh {
	conn := new(Ssh)
	conn.Config = config
	conn.withBar = false
	conn.InitDial()
	return conn
}
// initSession opens a new session on s.Client and stores it in s.Session,
// replacing whatever session was stored before. Failure is fatal.
func (s *Ssh) initSession() {
	sess, err := s.Client.NewSession()
	if err != nil {
		log.Fatal(err)
	}
	s.Session = sess
}
// AddBar registers bar as an additional destination for downloaded data:
// once set, downloadFile copies every byte to both the local file and bar
// (intended for a progress indicator, judging by the name).
func (s *Ssh) AddBar(bar io.Writer) {
	s.withBar = true
	s.bar = &bar
}
// Exec runs cmd in a new session on the connected host and returns
// everything the command wrote to its standard output.
//
// If close is true, the underlying SSH client is closed once the command has
// finished, so the Ssh value cannot be used for further commands.
//
// A failing command is fatal: the error is reported via log.Fatal together
// with the command and anything it wrote to standard error.
func (s *Ssh) Exec(cmd string, close bool) *bytes.Buffer {
	if close {
		defer s.Client.Close()
	}
	s.initSession()
	defer s.Session.Close()

	var stdout, stderr bytes.Buffer
	s.Session.Stdout = &stdout
	// Capture stderr (previously discarded) so a failure is reported with the
	// remote diagnostics rather than only the exit status.
	s.Session.Stderr = &stderr
	if err := s.Session.Run(cmd); err != nil {
		if msg := bytes.TrimSpace(stderr.Bytes()); len(msg) > 0 {
			log.Fatalf("running %q: %v: %s", cmd, err, msg)
		}
		log.Fatalf("running %q: %v", cmd, err)
	}
	return &stdout
}
// getHostKeyCB returns a host key callback that checks the server's key
// against the current user's known_hosts file (<home>/.ssh/known_hosts).
// Failing to determine the current user or to load that file is fatal.
func (s *Ssh) getHostKeyCB() ssh.HostKeyCallback {
	u, err := user.Current()
	if err != nil {
		log.Fatalf("looking up current user: %v", err)
	}
	// Use the account's real home directory instead of assuming
	// /home/<username>, which is wrong for root, macOS (/Users/...) and any
	// non-standard home layout.
	knownHostsPath := u.HomeDir + "/.ssh/known_hosts"
	hkcb, err := knownhosts.New(knownHostsPath)
	if err != nil {
		log.Fatalf("loading known hosts %q: %v", knownHostsPath, err)
	}
	return hkcb
}
// getSigners connects to the running ssh-agent, whose socket path is read
// from the SSH_AUTH_SOCK environment variable, and returns signers for the
// keys it holds. Any failure is fatal.
//
// The agent connection is intentionally left open: the returned signers
// delegate signing to the agent over that connection, so it must stay alive
// through the SSH handshake performed in InitDial.
func (s *Ssh) getSigners() []ssh.Signer {
	sockPath := os.Getenv("SSH_AUTH_SOCK")
	if sockPath == "" {
		log.Fatal("SSH_AUTH_SOCK is not set: a running ssh-agent is required")
	}
	conn, err := net.Dial("unix", sockPath)
	if err != nil {
		log.Fatalf("connecting to ssh-agent at %q: %v", sockPath, err)
	}
	// Named keyring to avoid shadowing the imported agent package.
	keyring := agent.NewClient(conn)
	signers, err := keyring.Signers()
	if err != nil {
		log.Fatalf("listing ssh-agent keys: %v", err)
	}
	return signers
}
// InitDial opens the SSH connection described by s.Config and stores it in
// s.Client. It authenticates with the keys held by the local ssh-agent
// (getSigners) and verifies the server against the user's known_hosts file
// (getHostKeyCB). Any failure is fatal.
//
// s.Config.Type is used as the dial network; when it is empty, "tcp" is used.
func (s *Ssh) InitDial() {
	cfg := &ssh.ClientConfig{
		User:            s.Config.User,
		Auth:            []ssh.AuthMethod{ssh.PublicKeys(s.getSigners()...)},
		HostKeyCallback: s.getHostKeyCB(),
	}
	cfg.SetDefaults()

	network := s.Config.Type
	if network == "" {
		network = "tcp"
	}
	// JoinHostPort brackets IPv6 literals ("[::1]:22"); "%s:%d" formatting
	// would produce an address that cannot be parsed for them.
	addr := net.JoinHostPort(s.Config.Host, fmt.Sprint(s.Config.Port))

	client, err := ssh.Dial(network, addr, cfg)
	if err != nil {
		log.Fatalf("dialing %s %s as user %q: %v", network, addr, s.Config.User, err)
	}
	s.Client = client
}
// Scp starts a new SFTP client over the existing SSH connection (despite the
// name, it uses the SFTP subsystem rather than the scp protocol). The caller
// owns the returned client and is responsible for closing it. If the
// subsystem cannot be started, a message is printed to stderr and the
// process exits with status 1.
func (s *Ssh) Scp() *sftp.Client {
	client, err := sftp.NewClient(s.Client)
	if err == nil {
		return client
	}
	fmt.Fprintf(os.Stderr, "Unable to start SFTP subsystem: %v\n", err)
	os.Exit(1)
	return nil // unreachable: os.Exit does not return
}
// ScpDownload copies remoteFile from the connected host to localFile using a
// freshly opened SFTP client; size is passed through to downloadFile.
//
// The SFTP client created here is always closed before returning, because
// the caller has no way to reach it afterwards. The close parameter is kept
// only for API compatibility and has no effect; the SSH client itself is
// never closed by this method.
//
// Download errors are fatal (log.Fatal), so the function only returns true.
func (s *Ssh) ScpDownload(size int64, remoteFile string, localFile string, close bool) bool {
	scp := s.Scp()
	defer scp.Close()
	if err := s.downloadFile(size, scp, remoteFile, localFile); err != nil {
		log.Fatal(err)
	}
	return true
}
// downloadFile streams remoteFile from sc into localFile, which is created
// or truncated via os.Create. When a writer has been registered with AddBar,
// every copied byte is also written to it. The size parameter is currently
// unused.
//
// Errors:
//   - A failure to open either file is printed to stderr and returned.
//   - A failure during the copy is printed to stderr and terminates the
//     process with status 1.
//   - A failure to close localFile is printed to stderr and returned. For a
//     file being written, this can mean the data was not fully stored.
func (s *Ssh) downloadFile(size int64, sc *sftp.Client, remoteFile string, localFile string) (err error) {
	// Note: SFTP To Go doesn't support O_RDWR mode
	srcFile, err := sc.OpenFile(remoteFile, os.O_RDONLY)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Unable to open remote file: %v\n", err)
		return err
	}
	defer srcFile.Close()

	fmt.Println()
	dstFile, err := os.Create(localFile)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Unable to open local file: %v\n", err)
		return err
	}
	// Surface the Close error of the written file instead of dropping it;
	// an earlier error takes precedence.
	defer func() {
		if cerr := dstFile.Close(); cerr != nil && err == nil {
			fmt.Fprintf(os.Stderr, "Unable to close local file: %v\n", cerr)
			err = cerr
		}
	}()

	var dst io.Writer = dstFile
	if s.withBar {
		dst = io.MultiWriter(dstFile, *s.bar)
	}
	_, err = io.Copy(dst, srcFile)
	// Move the cursor back over the blank line printed above.
	echo.LineUp(1)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Unable to download remote file: %v\n", err)
		os.Exit(1)
	}
	return nil
}