diff --git a/crypt/rsa.go b/crypt/rsa.go new file mode 100644 index 0000000..5e5b670 --- /dev/null +++ b/crypt/rsa.go @@ -0,0 +1,121 @@ +package crypt + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "fmt" + "log" + + "golang.org/x/crypto/ssh" +) + +type RsaEncrypt struct { + pub *rsa.PublicKey +} + +func NewRsaEncrypt(key []byte) *RsaEncrypt { + pkey, _, _, _, err := ssh.ParseAuthorizedKey(key) + if err != nil { + log.Fatal("unable to parse authorized key") + } + // upgrade first to ssh.CryptoPublicKey interface + // then call CryptoPublicKey() to get actual crypto.PublicKey + // Finally, convert back to an *rsa.PublicKey + pubCrypto := pkey.(ssh.CryptoPublicKey).CryptoPublicKey() + return &RsaEncrypt{pubCrypto.(*rsa.PublicKey)} +} + +func (re *RsaEncrypt) Encrypt(data []byte) (string, error) { + encryptedBytes, err := rsa.EncryptOAEP( + sha256.New(), + rand.Reader, + re.pub, + data, + nil) + if err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(encryptedBytes), nil +} + +type RsaKey struct { + priv *rsa.PrivateKey +} + +func NewRsaKey(size int) *RsaKey { + key, err := rsa.GenerateKey(rand.Reader, size) + if err != nil { + log.Fatalf("unable to generate key with size %d", size) + } + return &RsaKey{key} +} + +func LoadRsaKey(privKey []byte) *RsaKey { + block, _ := pem.Decode(privKey) + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + key.Size() + if err != nil { + log.Fatal("unable to parse pkcs1 priv key") + } + return &RsaKey{priv: key} +} + +func (rk *RsaKey) GetBytes() []byte { + return pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(rk.priv), + }) +} + +func (rk *RsaKey) GetPubKeyBytes() []byte { + pub, err := ssh.NewPublicKey(rk.priv.Public()) + if err != nil { + log.Fatal("unable to retriew public key") + } + return ssh.MarshalAuthorizedKey(pub) +} + +func (rk *RsaKey) GetRsaEncrypt() *RsaEncrypt { + return &RsaEncrypt{&rk.priv.PublicKey} +} + +func (rk *RsaKey) Decrypt(b64data []byte) ([]byte, error) { + data, err := base64.StdEncoding.DecodeString(string(b64data)) + if err != nil { + log.Fatal("unable to decode base64 data") + } + decrypted, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, rk.priv, data, nil) + if err != nil { + log.Fatal("unable to decrypt data") + } + return decrypted, nil +} + +func mainSshCrypt() { + rk := NewRsaKey(4096) + pubKey := rk.GetPubKeyBytes() + fmt.Println("== PUB KEY ==") + fmt.Println(string(pubKey)) + fmt.Println("== PRIV KEY ==") + fmt.Println(string(rk.GetBytes())) + + rk2 := LoadRsaKey(rk.GetBytes()) + fmt.Println("== LOADED KEY ==") + fmt.Println(string(rk2.GetBytes())) + + // fmt.Println("== GET RsaEncrypt ==") + // re := rk.GetRsaEncrypt() + fmt.Println("== NewRsaEncrypt ==") + re := NewRsaEncrypt(pubKey) + + fmt.Println("== ENCRYPT data ==") + encryptedData, _ := re.Encrypt([]byte("hello world")) + fmt.Println(encryptedData) + + decryptedData, _ := rk.Decrypt([]byte(encryptedData)) + fmt.Println("== DECRYPT data ==") + fmt.Println(string(decryptedData)) +} diff --git a/ssh/ssh.go b/ssh/ssh.go index 39ca35f..a626b6e 100644 --- a/ssh/ssh.go +++ b/ssh/ssh.go @@ -159,7 +159,7 @@ func (s *Ssh) DownloadFile(remoteFile string, localFile string, display bool, cl if close { defer scp.Close() } - if err := s.downloadFile(size, scp, remoteFile, localFile); err == nil { + if err := s.downloadFile(scp, remoteFile, localFile); err == nil { done = sys.CheckSumFile(checksum, localFile) if display { echo.Cstyle("usageCom").Echo(" file downloaded !\n") @@ -177,13 +177,13 @@ func (s *Ssh) DownloadFile(remoteFile string, localFile string, display bool, cl return done } -func (s *Ssh) ScpDownload(size int64, remoteFile string, localFile string, close bool) bool { +func (s *Ssh) ScpDownload(remoteFile string, localFile string, close bool) bool { var done bool = false scp := s.Scp() if close { defer scp.Close() } - if err := s.downloadFile(size, scp, remoteFile, localFile); err == nil { + if err := s.downloadFile(scp, remoteFile, localFile); err == nil { done = true } else { log.Fatal(err) @@ -191,7 +191,7 @@ func (s *Ssh) ScpDownload(size int64, remoteFile string, localFile string, close return done } -func (s *Ssh) downloadFile(size int64, sc *sftp.Client, remoteFile string, localFile string) (err error) { +func (s *Ssh) downloadFile(sc *sftp.Client, remoteFile string, localFile string) (err error) { // Note: SFTP To Go doesn't support O_RDWR mode srcFile, err := sc.OpenFile(remoteFile, (os.O_RDONLY))