提交 40295b9e 编写于 作者: D dongzhihong

"fix pserver saving etcd"

上级 bfc3b436
...@@ -20,6 +20,8 @@ func main() { ...@@ -20,6 +20,8 @@ func main() {
"comma separated endpoint string for pserver to connect to etcd") "comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job") numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Int("checkpoint-interval", "10", "save checkpoint per interval seconds")
logLevel := flag.String("log-level", "info", logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic") "log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse() flag.Parse()
...@@ -31,6 +33,7 @@ func main() { ...@@ -31,6 +33,7 @@ func main() {
log.SetLevel(level) log.SetLevel(level)
var idx int var idx int
var cp pserver.Checkpoint
if *index >= 0 { if *index >= 0 {
idx = *index idx = *index
} else { } else {
...@@ -42,7 +45,7 @@ func main() { ...@@ -42,7 +45,7 @@ func main() {
} }
} }
s, err := pserver.NewService(idx) s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
if err != nil { if err != nil {
panic(err) panic(err)
} }
......
...@@ -18,6 +18,8 @@ const ( ...@@ -18,6 +18,8 @@ const (
PsDesired = "/ps_desired" PsDesired = "/ps_desired"
// PsPath is the base dir for pserver to store their addr // PsPath is the base dir for pserver to store their addr
PsPath = "/ps/" PsPath = "/ps/"
// PsCheckpoint is the etcd path under which checkpoint information is stored
PsCheckpoint = "/checkpoints/"
) )
// EtcdClient is the etcd client that the pserver uses for fault // EtcdClient is the etcd client that the pserver uses for fault
...@@ -186,3 +188,14 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { ...@@ -186,3 +188,14 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
return idx, nil return idx, nil
} }
// PutKey stores value under key in etcd.
//
// timeout is the number of seconds the Put call is allowed to take; the
// request context is cancelled once that deadline passes. The error from
// the underlying Put, if any, is returned unchanged.
func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error {
	// WithTimeout returns the context and its cancel function (not an error).
	// The cancel function must always be called to release the timer.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
	defer cancel()

	// NOTE(review): etcd clientv3's Put takes a string value; if e.Put is
	// that method, this argument needs to be string(value) — confirm.
	_, err := e.Put(ctx, key, value)
	return err
}
...@@ -35,12 +35,12 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte { ...@@ -35,12 +35,12 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
return (*[1 << 30]byte)(p)[:len:len] return (*[1 << 30]byte)(p)[:len:len]
} }
func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer { func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer {
o := &optimizer{} o := &optimizer{}
o.elementType = paramWithConfigs.Param.ElementType o.elementType = paramWithConfigs.Param.ElementType
p := paramWithConfigs.Param p := paramWithConfigs.Param
c := paramWithConfigs.Config c := paramWithConfigs.Config
s := paramWithConfigs.State s := State
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"ElementType": p.ElementType, "ElementType": p.ElementType,
"ParamSize": len(p.Content), "ParamSize": len(p.Content),
......
...@@ -5,10 +5,11 @@ import ( ...@@ -5,10 +5,11 @@ import (
"bytes" "bytes"
"crypto/md5" "crypto/md5"
"encoding/gob" "encoding/gob"
"encoding/hex" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"os" "os"
"path/filepath"
"strconv" "strconv"
"sync" "sync"
"time" "time"
...@@ -26,10 +27,6 @@ const ( ...@@ -26,10 +27,6 @@ const (
Uninitialized = "pserver not fully initialized" Uninitialized = "pserver not fully initialized"
) )
const (
checkpoint_path = "./checkpoints/"
)
// Supported element types // Supported element types
const ( const (
Int32 ElementType = iota Int32 ElementType = iota
...@@ -51,49 +48,68 @@ type Parameter struct { ...@@ -51,49 +48,68 @@ type Parameter struct {
type ParameterWithConfig struct { type ParameterWithConfig struct {
Param Parameter Param Parameter
Config []byte // parameter configuration in Proto Buffer format Config []byte // parameter configuration in Proto Buffer format
State []byte // parameter training state
} }
// parameterCheckPoint is one entry of a Checkpoint: a parameter with its
// configuration, together with the optimizer state saved for it.
type parameterCheckPoint struct {
// ParamConfig holds the parameter itself and its configuration.
ParamConfig ParameterWithConfig
// State is the optimizer state bytes; doCheckpoint fills it from the
// optimizer's GetStates and NewService hands it back to newOptimizer.
State []byte
}
// checkpointMeta is the checkpoint signature. doCheckpoint marshals it to
// JSON and puts it into etcd.
type checkpointMeta struct {
// UUID is the checkpoint path followed by the pserver index; the same
// string is used as the name of the checkpoint file.
UUID string `json:"uuid"`
// Md5sum is meant to carry the MD5 digest of the encoded checkpoint.
Md5sum string `json:"md5sum"`
// Timestamp is time.Now().String() taken when the checkpoint is made.
Timestamp string `json:"timestamp"`
}
// Checkpoint is the state of one pserver shard as persisted to a file:
// one parameterCheckPoint per parameter.
type Checkpoint []parameterCheckPoint
// Gradient is the gradient of the parameter. // Gradient is the gradient of the parameter.
type Gradient Parameter
// Service is the RPC service for pserver. // Service is the RPC service for pserver.
type Service struct { type Service struct {
initialized chan struct{} initialized chan struct{}
idx int idx int
checkpointInterval int
mu sync.Mutex checkpointPath string
optMap map[string]*optimizer client *EtcdClient
mu sync.Mutex
optMap map[string]*optimizer
} }
type checkpoint struct { // //serialize ParameterWithConfig to byte stream
Uuid string // func GetBytes(content ...interface{}) ([]byte, error) {
Md5sum string
Timestamp string
}
//serialize ParameterWithConfig to byte stream // var buf bytes.Buffer
func GetBytes(content ...interface{}) ([]byte, error) { // encoder := gob.NewEncoder(&buf)
// err := encoder.Encode(content)
var buf bytes.Buffer // if err != nil {
encoder := gob.NewEncoder(&buf) // return nil, err
err := encoder.Encode(content) // }
if err != nil { // return buf.Bytes(), nil
return nil, err // }
}
return buf.Bytes(), nil
}
// NewService creates a new service, will bypass etcd registration if no // NewService creates a new service, will bypass etcd registration if no
// endpoints specified. // endpoints specified.
func NewService(idx int) (*Service, error) { func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
s := &Service{ s := &Service{
idx: idx, idx: idx,
checkpointInterval: time.Second * time.Duration(seconds),
checkpointPath: path,
client: client,
} }
s.optMap = make(map[string]*optimizer) s.optMap = make(map[string]*optimizer)
s.initialized = make(chan struct{}) s.initialized = make(chan struct{})
gob.Register(ParameterWithConfig{})
gob.Register(checkpoint{}) if cp != nil {
for _, item := range cp {
p := item.ParamConfig
st := item.State
s.optMap[p.Param.Name] = newOptimizer(p, st)
}
}
return s, nil return s, nil
} }
...@@ -174,53 +190,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { ...@@ -174,53 +190,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
return nil return nil
} }
// Save tells the parameter server to save parameters. // pserver save checkpoint
func (s *Service) Save(path string, dummy *int) error { func (s *Service) doCheckpoint() error {
//FIXME: checkpoint is only used by pserver
// and has a constant path of */checkpoints/{pserver_idx}*
<-s.initialized <-s.initialized
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
var paramWithConfig ParameterWithConfig
cp := make([]parameterCheckPoint, 0, len(s.optMap))
index := 0
for name, opt := range s.optMap { for name, opt := range s.optMap {
paramWithConfig.Param.Name = name var pc parameterCheckPoint
paramWithConfig.Param.ElementType = opt.elementType pc.ParamConfig.Param.Name = name
paramWithConfig.Param.Content = opt.GetWeights() pc.ParamConfig.Param.ElementType = opt.elementType
paramWithConfig.State = opt.GetStates() pc.ParamConfig.Param.Content = opt.GetWeights()
content, err := GetBytes(paramWithConfig) pc.State = opt.GetStates()
if err != nil { cp[index] = pc
log.Errorln(err) index++
} }
ck := checkpoint{} var buf bytes.Buffer
h := md5.New() encoder := gob.NewEncoder(&buf)
ck.Md5sum = hex.EncodeToString(h.Sum(content)) err := encoder.Encode(cp)
ck.Timestamp = time.Now().String() if err != nil {
ck.Uuid = checkpoint_path + strconv.Itoa(s.idx) return err
ckbytes, err := GetBytes(ck) }
if err != nil {
log.Errorln(err) cpMeta := checkpointMeta{}
} cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
// TODO: according design doc, need to save Uuid to etcd in json format cpMeta.Timestamp = time.Now().String()
// {\"Uuid\": [UUID], \"md5\", \"MD5 sum\", \"Timestamp\": xxxx} h := md5.New()
log.Infof("parameter checkpoint %s", ckbytes) cpMeta.Md5sum = h.Sum(buf.Bytes())
if _, err = os.Stat(ck.Uuid); os.IsNotExist(err) { cpMetajson, err := json.Marshal(cpMeta)
log.Info("checkpoint not exists.") s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3)
} else { if err != nil {
err = os.Remove(ck.Uuid) return err
log.Infof("remove %s", ck.Uuid) }
} if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) {
f, err := os.Create(ck.Uuid) log.Info("checkpoint does not exists.")
defer f.Close() } else {
if err != nil { err = os.Remove(cpMeta.UUID)
log.Errorln(err) log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
} }
writer := bufio.NewWriter(f) f, err := os.Create(cpMeta.UUID)
_, err = writer.Write(content) defer f.Close()
writer.Flush() if err != nil {
if err != nil { log.Errorln(err)
log.Errorln(err) }
} writer := bufio.NewWriter(f)
_, err = writer.Write(buf.Bytes())
writer.Flush()
if err != nil {
log.Errorln(err)
} }
return nil return nil
} }
...@@ -15,7 +15,8 @@ const ( ...@@ -15,7 +15,8 @@ const (
) )
func TestServiceFull(t *testing.T) { func TestServiceFull(t *testing.T) {
s, err := pserver.NewService(0) var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
...@@ -83,8 +84,6 @@ func TestServiceFull(t *testing.T) { ...@@ -83,8 +84,6 @@ func TestServiceFull(t *testing.T) {
if !reflect.DeepEqual(param1, p) { if !reflect.DeepEqual(param1, p) {
t.FailNow() t.FailNow()
} }
var dummy int
s.Save("", &dummy)
} }
func TestMultipleInit(t *testing.T) { func TestMultipleInit(t *testing.T) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册