提交 0ad7053e 编写于 作者: D dongzhihong

"make parameterCheckpoint exported"

上级 87e7924e
......@@ -21,7 +21,7 @@ func main() {
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Int("checkpoint-interval", 10, "save checkpoint per interval seconds")
checkpointInterval := flag.Int("checkpoint-interval", 600, "save checkpoint per interval seconds")
logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse()
......
......@@ -51,8 +51,8 @@ type ParameterWithConfig struct {
Config []byte // parameter configuration in Proto Buffer format
}
// Checkpoint of Parameter and State
type parameterCheckPoint struct {
// ParameterCheckpoint is Parameter and State checkpoint
type ParameterCheckpoint struct {
ParamConfig ParameterWithConfig
State []byte
}
......@@ -65,7 +65,7 @@ type checkpointMeta struct {
}
// Checkpoint is the pserver shard persist in file
type Checkpoint []parameterCheckPoint
type Checkpoint []ParameterCheckpoint
// Gradient is the gradient of the parameter.
type Gradient Parameter
......@@ -186,10 +186,10 @@ func (s *Service) doCheckpoint() error {
s.mu.Lock()
defer s.mu.Unlock()
cp := make([]parameterCheckPoint, 0, len(s.optMap))
cp := make([]ParameterCheckpoint, 0, len(s.optMap))
index := 0
for name, opt := range s.optMap {
var pc parameterCheckPoint
var pc ParameterCheckpoint
pc.ParamConfig.Param.Name = name
pc.ParamConfig.Param.ElementType = opt.elementType
pc.ParamConfig.Param.Content = opt.GetWeights()
......@@ -210,8 +210,8 @@ func (s *Service) doCheckpoint() error {
h := md5.New()
cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes()))
cpMetajson, err := json.Marshal(cpMeta)
s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3)
cpMetajson, _ := json.Marshal(cpMeta)
err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3)
if err != nil {
return err
}
......@@ -224,13 +224,13 @@ func (s *Service) doCheckpoint() error {
f, err := os.Create(cpMeta.UUID)
defer f.Close()
if err != nil {
log.Errorln(err)
return err
}
writer := bufio.NewWriter(f)
_, err = writer.Write(buf.Bytes())
writer.Flush()
if err != nil {
log.Errorln(err)
return err
}
return nil
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册