core/ssh/ssh.go

229 lines
5.1 KiB
Go

package ssh
import (
"bytes"
"fmt"
"io"
"log"
"net"
"os"
"os/user"
"path"
"strings"
"gitea.meta-tech.academy/go/core/echo"
"gitea.meta-tech.academy/go/core/sys"
"gitea.meta-tech.academy/go/core/util"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"golang.org/x/crypto/ssh/knownhosts"
)
// SshAuth holds the connection settings used by InitDial to reach a remote
// host. The yaml tags let it be decoded from a YAML configuration.
type SshAuth struct {
// Host is the remote host name or address; it is combined with Port to
// build the dial address.
Host string `yaml:"host"`
// Type is the network name handed to ssh.Dial (presumably "tcp" — TODO confirm).
Type string `yaml:"type"`
// Port is the remote SSH port.
Port int `yaml:"port"`
// User is the login name placed in ssh.ClientConfig.User.
User string `yaml:"user"`
}
// Ssh bundles an SSH client connection with the session used by the most
// recent Exec call and an optional progress writer used while downloading.
type Ssh struct {
// Config holds the settings InitDial uses to open Client.
Config *SshAuth
// Client is the connection opened by InitDial.
Client *ssh.Client
// Session is the session created by the latest Exec call; Exec closes it
// before returning.
Session *ssh.Session
// bar receives a copy of every downloaded byte when withBar is true.
// NOTE(review): a pointer to an interface is rarely needed; a plain
// io.Writer field would be more idiomatic.
bar *io.Writer
// withBar reports whether a progress writer was registered through AddBar.
withBar bool
}
// NewSsh creates an Ssh for config and dials the remote host right away via
// InitDial, so the returned value is ready for Exec/DownloadFile. InitDial
// aborts the process when the connection cannot be established.
func NewSsh(config *SshAuth) *Ssh {
	conn := new(Ssh)
	conn.Config = config
	conn.withBar = false
	conn.InitDial()
	return conn
}
// initSession opens a new session on the current client and stores it in
// s.Session. Failing to open the session terminates the process.
func (s *Ssh) initSession() {
	session, err := s.Client.NewSession()
	if err != nil {
		log.Fatal(err)
	}
	s.Session = session
}
// AddBar registers bar as a progress sink: while a file is downloaded, every
// chunk written to the local file is also written to bar (typically a
// progress bar).
//
// Passing nil clears any previously registered writer. Registering a nil
// writer would otherwise make the download panic on its first write, because
// io.MultiWriter would call Write on a nil interface.
func (s *Ssh) AddBar(bar io.Writer) {
	if bar == nil {
		s.bar = nil
		s.withBar = false
		return
	}
	s.bar = &bar
	s.withBar = true
}
// Exec runs cmd in a fresh session on the connected client and returns the
// command's standard output.
//
// When close is true, the underlying SSH client is closed once the command
// has finished, so the Ssh value cannot be used for further commands.
//
// A failing command (non-zero exit status or transport error) terminates the
// process through log.Fatal. The remote standard error is captured and
// appended to that message. Without it, the log would only show the exit
// status, because a nil Session.Stderr is discarded.
func (s *Ssh) Exec(cmd string, close bool) *bytes.Buffer {
	if close {
		defer s.Client.Close()
	}
	s.initSession()
	defer s.Session.Close()

	var stdout, stderr bytes.Buffer
	s.Session.Stdout = &stdout
	s.Session.Stderr = &stderr
	if err := s.Session.Run(cmd); err != nil {
		if msg := strings.TrimSpace(stderr.String()); msg != "" {
			log.Fatalf("%v: %s", err, msg)
		}
		log.Fatal(err)
	}
	return &stdout
}
// getHostKeyCB returns a host-key callback that checks the server key against
// the current user's ~/.ssh/known_hosts file.
//
// The file is located from the account's home directory instead of a
// hard-coded /home/<name> prefix, so it also works for root (/root), macOS
// (/Users/<name>) and accounts with non-standard homes. If the home directory
// is unknown, it falls back to /home/<name>.
//
// Failing to resolve the current user aborts the process. An unreadable or
// malformed known_hosts file panics.
func (s *Ssh) getHostKeyCB() ssh.HostKeyCallback {
	u, err := user.Current()
	if err != nil {
		log.Fatal(err)
	}
	home := u.HomeDir
	if home == "" {
		home = path.Join("/home", u.Username)
	}
	hkcb, err := knownhosts.New(path.Join(home, ".ssh", "known_hosts"))
	if err != nil {
		panic(err)
	}
	// ssh.InsecureIgnoreHostKey() would skip host verification entirely and
	// is deliberately not used here.
	return hkcb
}
// getSigners connects to the running ssh-agent, located through the
// SSH_AUTH_SOCK environment variable, and returns the signers for the keys it
// holds.
//
// The agent connection is intentionally left open. Signers obtained from an
// agent client perform their signatures through that connection, so it must
// remain usable while InitDial performs the SSH handshake.
//
// Any failure, including a missing SSH_AUTH_SOCK, aborts the process.
func (s *Ssh) getSigners() []ssh.Signer {
	sockPath := os.Getenv("SSH_AUTH_SOCK")
	if sockPath == "" {
		log.Fatal("SSH_AUTH_SOCK is not set: no ssh-agent available for public key authentication")
	}
	conn, err := net.Dial("unix", sockPath)
	if err != nil {
		log.Fatal(err)
	}
	signers, err := agent.NewClient(conn).Signers()
	if err != nil {
		log.Fatal(err)
	}
	return signers
}
// InitDial opens the SSH connection described by s.Config and stores it in
// s.Client.
//
// Authentication uses the keys held by the local ssh-agent (getSigners). The
// server key is verified against the user's known_hosts file (getHostKeyCB).
//
// The address is built with net.JoinHostPort, so IPv6 literals such as "::1"
// are bracketed correctly. A host already written as "[::1]" is accepted too.
// An empty Config.Type defaults to "tcp" and a zero Config.Port to 22.
//
// A dial error aborts the process.
func (s *Ssh) InitDial() {
	auths := []ssh.AuthMethod{ssh.PublicKeys(s.getSigners()...)}
	cfg := &ssh.ClientConfig{
		User:            s.Config.User,
		Auth:            auths,
		HostKeyCallback: s.getHostKeyCB(),
	}
	cfg.SetDefaults()

	network := s.Config.Type
	if network == "" {
		network = "tcp"
	}
	port := s.Config.Port
	if port == 0 {
		port = 22
	}
	host := s.Config.Host
	// Strip user-supplied brackets so JoinHostPort does not double them.
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}
	addr := net.JoinHostPort(host, fmt.Sprint(port))

	var err error
	s.Client, err = ssh.Dial(network, addr, cfg)
	if err != nil {
		log.Fatal(err)
	}
}
// Scp starts the SFTP subsystem over the existing SSH connection and returns
// a client for it. If the subsystem cannot be started, the error is printed
// to stderr and the process exits with status 1.
func (s *Ssh) Scp() *sftp.Client {
	client, err := sftp.NewClient(s.Client)
	if err == nil {
		return client
	}
	fmt.Fprintf(os.Stderr, "Unable to start SFTP subsystem: %v\n", err)
	os.Exit(1)
	return nil
}
// ApplyOnSize is an optional hook for DownloadFile. It receives the Ssh value
// and the remote file size in bytes before any transfer starts (for example,
// to size a progress bar). The size is the fallback -1 when the remote `du`
// output could not be parsed — assuming util.Str2int64 returns its last
// argument on failure; TODO confirm.
type ApplyOnSize func(s *Ssh, size int64)

// quoteShellArg wraps arg in single quotes for a POSIX shell, rewriting each
// embedded single quote as '\'' so the shell receives arg as one literal
// word with no expansion.
func quoteShellArg(arg string) string {
	return "'" + strings.ReplaceAll(arg, "'", `'\''`) + "'"
}

// DownloadFile copies remoteFile to localFile over SFTP, unless localFile
// already matches the remote file's size and SHA-256 checksum.
//
// The remote size (`du --apparent-size`) and checksum (`sha256sum`) come from
// shell commands run on the remote host. The path is single-quoted, so names
// containing spaces, quotes, `$` or backticks reach those commands literally,
// exactly as the SFTP open sees them.
//
// If fn is non-empty, fn[0] is called with the remote size before any
// transfer; additional functions are ignored. When display is true, progress
// messages are printed through the echo package. When close is true, the SFTP
// client opened for the transfer is closed afterwards.
//
// It returns true when localFile ends up matching the remote checksum, either
// because it was already present or because the fresh download verified.
// This holds regardless of display. A failed remote command or transfer
// aborts the process.
func (s *Ssh) DownloadFile(remoteFile string, localFile string, display bool, close bool, fn ...ApplyOnSize) bool {
	if display {
		echo.Action("Downloading remote file", path.Base(remoteFile))
	}
	// "--" keeps a file name starting with "-" from being parsed as an option.
	quoted := quoteShellArg(remoteFile)

	out := s.Exec(fmt.Sprintf("du --apparent-size --block-size=1 -- %s | awk '{ print $1}'", quoted), false)
	// Drop the trailing newline so only the digits are parsed.
	size := util.Str2int64(strings.TrimSpace(out.String()), 10, -1)
	if len(fn) > 0 {
		fn[0](s, size)
	}

	out = s.Exec(fmt.Sprintf("sha256sum -- %s | cut -d' ' -f1", quoted), false)
	// GNU sha256sum prefixes its output line with '\' when the file name
	// contains a backslash or newline; that marker is not part of the digest.
	checksum := strings.TrimPrefix(strings.TrimSpace(out.String()), "\\")

	alreadyDownload := false
	if sys.CheckFileSize(size, localFile) {
		if display {
			echo.Cstyle("usageCom").Echo(" already downloaded\n")
		}
		alreadyDownload = sys.CheckSumFile(checksum, localFile)
		if display {
			if alreadyDownload {
				echo.Cstyle("usageCom").Echo(" file integrity confirmed\n")
				echo.State(alreadyDownload)
			} else {
				echo.Cstyle("podFullName").Echo(" file seems corrupt, retry download\n")
			}
		}
	} else if display {
		echo.Cstyle("podFullName").Echo(" file seems corrupt, retry download\n")
	}
	if alreadyDownload {
		// Report success whether or not anything was displayed.
		return true
	}

	scp := s.Scp()
	if close {
		defer scp.Close()
	}
	if err := s.downloadFile(size, scp, remoteFile, localFile); err != nil {
		log.Fatal(err)
	}
	done := sys.CheckSumFile(checksum, localFile)
	if display {
		echo.Cstyle("usageCom").Echo(" file downloaded !\n")
		if done {
			echo.Cstyle("usageCom").Echo(" file integrity confirmed\n")
		} else {
			echo.Cstyle("usageCom").Echo(" file seems corrupt, abort\n")
		}
		echo.State(done)
	}
	return done
}
// ScpDownload transfers remoteFile to localFile over a fresh SFTP client
// without any size or checksum verification. When close is true, the SFTP
// client is closed afterwards. Any transfer error aborts the process through
// log.Fatal, so a returned value is always true.
func (s *Ssh) ScpDownload(size int64, remoteFile string, localFile string, close bool) bool {
	client := s.Scp()
	if close {
		defer client.Close()
	}
	err := s.downloadFile(size, client, remoteFile, localFile)
	if err != nil {
		log.Fatal(err)
	}
	return true
}
// downloadFile streams remoteFile from the SFTP client sc into localFile,
// creating or truncating localFile. When a progress writer was registered
// with AddBar, every copied chunk is also written to it.
//
// size is not used by this function. NOTE(review): presumably kept for
// progress-bar sizing; confirm or drop.
//
// Every failure is reported on stderr and returned, including a failed copy.
// Returning (instead of exiting directly) lets both files be closed and
// leaves the decision to the caller. An error from closing localFile is also
// returned, since for a written file it can mean the data never reached disk.
func (s *Ssh) downloadFile(size int64, sc *sftp.Client, remoteFile string, localFile string) (err error) {
	// Note: SFTP To Go doesn't support O_RDWR mode
	srcFile, err := sc.OpenFile(remoteFile, os.O_RDONLY)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Unable to open remote file: %v\n", err)
		return err
	}
	defer srcFile.Close()
	fmt.Println()
	dstFile, err := os.Create(localFile)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Unable to open local file: %v\n", err)
		return err
	}
	defer func() {
		// Surface a close failure only when nothing else went wrong first.
		if cerr := dstFile.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	var dst io.Writer = dstFile
	if s.withBar {
		dst = io.MultiWriter(dstFile, *s.bar)
	}
	_, err = io.Copy(dst, srcFile)
	echo.LineUp(1)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Unable to download remote file: %v\n", err)
		return err
	}
	return nil
}