sftp.go 2.3 KB
Newer Older
Y
Your Name 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
package utils

import (
	"fmt"
	"os"
	"path"

	"github.com/pkg/sftp"

	"golang.org/x/crypto/ssh"
)

// SftpClient wraps an SFTP session and the bookkeeping for one file upload
// performed by ScpCopy.
type SftpClient struct {
	// client is the SFTP session used for all remote operations.
	client *sftp.Client

	// totalSize is the size of the local file being uploaded (set by
	// ScpCopy); GetProcess reports passSize as a percentage of it.
	passSize, totalSize int64

	// finish carries ScpCopy's result to Finish; Close closes it.
	finish chan error
}

// NewSftpClient opens an SFTP session over the given SSH connection (via
// SftpConnect) and wraps it in an SftpClient whose transfer counters start at
// zero. Any error from SftpConnect is returned unchanged with a nil client.
//
// The finish channel is buffered with room for one result. ScpCopy always
// sends its outcome on that channel. With an unbuffered channel, that send
// blocks until some goroutine is waiting in Finish. A plain synchronous
// `err := c.ScpCopy(...)` would then hang forever, and an asynchronous copy
// whose Finish is never called would leak its goroutine. One slot lets ScpCopy
// deposit its result and return, and Finish picks it up later. Callers must
// still call Finish between consecutive ScpCopy calls.
func NewSftpClient(client *ssh.Client) (*SftpClient, error) {
	sc, err := SftpConnect(client)
	if err != nil {
		return nil, err
	}
	return &SftpClient{
		client: sc,
		finish: make(chan error, 1),
	}, nil
}

//ScpCopy scp复制
func (s *SftpClient) ScpCopy(localFilePath, remoteDir string) error {
	var (
		err error
	)

	srcFile, err := os.Open(localFilePath)
	if err != nil {
		s.finish <- err
		return err
	}
	defer srcFile.Close()
	fInfo, err := srcFile.Stat()
	if err != nil {
		s.finish <- err
		return err
	}
	s.totalSize = fInfo.Size()

	var remoteFileName = path.Base(localFilePath)
	tmpFile := path.Join(remoteDir, fmt.Sprintf("%s%s", remoteFileName, ".tmp"))

	dstFile, err := s.client.Create(tmpFile)
	if err != nil {
		s.finish <- err
		return err
	}
	defer dstFile.Close()
	fmt.Println(dstFile.ReadFrom(srcFile))

	_, err = dstFile.ReadFrom(srcFile)
	if err != nil {
		s.finish <- err
		return err
	}
	s.client.Rename(tmpFile, path.Join(remoteDir, remoteFileName))
	s.finish <- nil

	return nil
}

// GetProcess returns transfer progress as a percentage string with two
// decimals, computed as passSize*100/totalSize. It returns "0.00" while
// totalSize is still zero, which also avoids dividing by zero.
func (s *SftpClient) GetProcess() string {
	total := s.totalSize
	if total != 0 {
		percent := float64(s.passSize) * 100 / float64(total)
		return fmt.Sprintf("%.2f", percent)
	}
	return "0.00"
}

// CheckPathIsExisted stats path on the remote side. It returns nil when the
// Stat call succeeds and otherwise the Stat error unchanged.
func (s *SftpClient) CheckPathIsExisted(path string) error {
	if _, err := s.client.Stat(path); err != nil {
		return err
	}
	return nil
}

// Finish blocks until a result is available on the finish channel (ScpCopy
// sends one per call) and returns it. Once the channel has been closed by
// Close and drained, the receive yields the zero value, so Finish returns nil.
func (s *SftpClient) Finish() error {
	result := <-s.finish
	return result
}

// Close closes the finish channel, releasing any goroutine blocked in Finish.
// It then closes the SFTP client and returns that Close error.
//
// NOTE(review): calling Close twice panics (close of closed channel). Calling
// it while an ScpCopy has not yet sent its result makes that send panic (send
// on closed channel).
func (s *SftpClient) Close() error {
	done := s.finish
	close(done)

	session := s.client
	return session.Close()
}

// SftpConnect creates an SFTP client on top of the given SSH connection using
// sftp.NewClient. On failure it returns a nil client and that error unchanged.
func SftpConnect(client *ssh.Client) (*sftp.Client, error) {
	sc, err := sftp.NewClient(client)
	if err != nil {
		return nil, err
	}
	return sc, nil
}