提交 40295b9e 编写于 作者: D dongzhihong

"fix pserver saving etcd"

上级 bfc3b436
......@@ -20,6 +20,8 @@ func main() {
"comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Int("checkpoint-interval", "10", "save checkpoint per interval seconds")
logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse()
......@@ -31,6 +33,7 @@ func main() {
log.SetLevel(level)
var idx int
var cp pserver.Checkpoint
if *index >= 0 {
idx = *index
} else {
......@@ -42,7 +45,7 @@ func main() {
}
}
s, err := pserver.NewService(idx)
s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
if err != nil {
panic(err)
}
......
......@@ -18,6 +18,8 @@ const (
PsDesired = "/ps_desired"
// PsPath is the base dir under which pservers store their addresses
PsPath = "/ps/"
// PsCheckpoint is the etcd path for storing checkpoint information
PsCheckpoint = "/checkpoints/"
)
// EtcdClient is the etcd client that the pserver uses for fault
......@@ -186,3 +188,14 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
return idx, nil
}
// PutKey stores value under key in etcd, giving up after timeout seconds.
// It returns the error reported by the underlying Put call, or nil on
// success.
func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error {
	// context.WithTimeout returns the derived context and its cancel
	// function (not an error); cancel must always run so the timer held by
	// the context is released.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
	defer cancel()

	// NOTE(review): if Put is etcd clientv3's KV.Put, its value parameter is
	// a string and this call needs string(value) — confirm against the
	// EtcdClient definition.
	_, err := e.Put(ctx, key, value)
	return err
}
......@@ -35,12 +35,12 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
return (*[1 << 30]byte)(p)[:len:len]
}
func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer {
func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer {
o := &optimizer{}
o.elementType = paramWithConfigs.Param.ElementType
p := paramWithConfigs.Param
c := paramWithConfigs.Config
s := paramWithConfigs.State
s := State
log.WithFields(log.Fields{
"ElementType": p.ElementType,
"ParamSize": len(p.Content),
......
......@@ -5,10 +5,11 @@ import (
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"sync"
"time"
......@@ -26,10 +27,6 @@ const (
Uninitialized = "pserver not fully initialized"
)
const (
checkpoint_path = "./checkpoints/"
)
// Supported element types
const (
Int32 ElementType = iota
......@@ -51,49 +48,68 @@ type Parameter struct {
// ParameterWithConfig bundles a parameter with its serialized configuration
// and training state.
// NOTE(review): parameterCheckPoint also carries a State field; confirm
// which of the two is authoritative.
type ParameterWithConfig struct {
Param Parameter
Config []byte // parameter configuration in Proto Buffer format
State []byte // parameter training state
}
// parameterCheckPoint is the checkpoint entry for a single parameter: the
// parameter together with its configuration, plus the optimizer state bytes
// saved alongside it.
type parameterCheckPoint struct {
ParamConfig ParameterWithConfig
State []byte
}
// checkpointMeta is the checkpoint signature. It is marshalled to JSON
// using the key names given in the struct tags.
type checkpointMeta struct {
UUID string `json:"uuid"`
Md5sum string `json:"md5sum"`
Timestamp string `json:"timestamp"`
}
// Checkpoint is the pserver shard as persisted to file: a list of
// per-parameter checkpoint entries.
type Checkpoint []parameterCheckPoint
// Gradient is the gradient of the parameter. It has the same underlying
// layout as Parameter.
type Gradient Parameter
// Service is the RPC service for pserver.
type Service struct {
initialized chan struct{}
idx int
mu sync.Mutex
optMap map[string]*optimizer
initialized chan struct{}
idx int
checkpointInterval int
checkpointPath string
client *EtcdClient
mu sync.Mutex
optMap map[string]*optimizer
}
// checkpoint records the identifier, MD5 sum and timestamp of a saved
// checkpoint.
type checkpoint struct {
Uuid string
Md5sum string
Timestamp string
}
// //serialize ParameterWithConfig to byte stream
// func GetBytes(content ...interface{}) ([]byte, error) {
// GetBytes gob-encodes the given values (as one []interface{} value) and
// returns the resulting byte stream. A nil slice is returned together with
// the encoder's error when encoding fails.
func GetBytes(content ...interface{}) ([]byte, error) {
	out := new(bytes.Buffer)
	if encodeErr := gob.NewEncoder(out).Encode(content); encodeErr != nil {
		return nil, encodeErr
	}
	return out.Bytes(), nil
}
// var buf bytes.Buffer
// encoder := gob.NewEncoder(&buf)
// err := encoder.Encode(content)
// if err != nil {
// return nil, err
// }
// return buf.Bytes(), nil
// }
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
func NewService(idx int) (*Service, error) {
func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
s := &Service{
idx: idx,
idx: idx,
checkpointInterval: time.Second * time.Duration(seconds),
checkpointPath: path,
client: client,
}
s.optMap = make(map[string]*optimizer)
s.initialized = make(chan struct{})
gob.Register(ParameterWithConfig{})
gob.Register(checkpoint{})
if cp != nil {
for _, item := range cp {
p := item.ParamConfig
st := item.State
s.optMap[p.Param.Name] = newOptimizer(p, st)
}
}
return s, nil
}
......@@ -174,53 +190,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
return nil
}
// Save tells the parameter server to save parameters.
func (s *Service) Save(path string, dummy *int) error {
//FIXME: checkpoint is only used by pserver
// and has a constant path of */checkpoints/{pserver_idx}*
// pserver save checkpoint
func (s *Service) doCheckpoint() error {
<-s.initialized
s.mu.Lock()
defer s.mu.Unlock()
var paramWithConfig ParameterWithConfig
cp := make([]parameterCheckPoint, 0, len(s.optMap))
index := 0
for name, opt := range s.optMap {
paramWithConfig.Param.Name = name
paramWithConfig.Param.ElementType = opt.elementType
paramWithConfig.Param.Content = opt.GetWeights()
paramWithConfig.State = opt.GetStates()
content, err := GetBytes(paramWithConfig)
if err != nil {
log.Errorln(err)
}
ck := checkpoint{}
h := md5.New()
ck.Md5sum = hex.EncodeToString(h.Sum(content))
ck.Timestamp = time.Now().String()
ck.Uuid = checkpoint_path + strconv.Itoa(s.idx)
ckbytes, err := GetBytes(ck)
if err != nil {
log.Errorln(err)
}
// TODO: according design doc, need to save Uuid to etcd in json format
// {\"Uuid\": [UUID], \"md5\", \"MD5 sum\", \"Timestamp\": xxxx}
log.Infof("parameter checkpoint %s", ckbytes)
if _, err = os.Stat(ck.Uuid); os.IsNotExist(err) {
log.Info("checkpoint not exists.")
} else {
err = os.Remove(ck.Uuid)
log.Infof("remove %s", ck.Uuid)
}
f, err := os.Create(ck.Uuid)
defer f.Close()
if err != nil {
log.Errorln(err)
}
writer := bufio.NewWriter(f)
_, err = writer.Write(content)
writer.Flush()
if err != nil {
log.Errorln(err)
}
var pc parameterCheckPoint
pc.ParamConfig.Param.Name = name
pc.ParamConfig.Param.ElementType = opt.elementType
pc.ParamConfig.Param.Content = opt.GetWeights()
pc.State = opt.GetStates()
cp[index] = pc
index++
}
var buf bytes.Buffer
encoder := gob.NewEncoder(&buf)
err := encoder.Encode(cp)
if err != nil {
return err
}
cpMeta := checkpointMeta{}
cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
cpMeta.Timestamp = time.Now().String()
h := md5.New()
cpMeta.Md5sum = h.Sum(buf.Bytes())
cpMetajson, err := json.Marshal(cpMeta)
s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3)
if err != nil {
return err
}
if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) {
log.Info("checkpoint does not exists.")
} else {
err = os.Remove(cpMeta.UUID)
log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
}
f, err := os.Create(cpMeta.UUID)
defer f.Close()
if err != nil {
log.Errorln(err)
}
writer := bufio.NewWriter(f)
_, err = writer.Write(buf.Bytes())
writer.Flush()
if err != nil {
log.Errorln(err)
}
return nil
}
......@@ -15,7 +15,8 @@ const (
)
func TestServiceFull(t *testing.T) {
s, err := pserver.NewService(0)
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil {
t.Error(err)
}
......@@ -83,8 +84,6 @@ func TestServiceFull(t *testing.T) {
if !reflect.DeepEqual(param1, p) {
t.FailNow()
}
var dummy int
s.Save("", &dummy)
}
func TestMultipleInit(t *testing.T) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册