提交 0ad7053e 作者: dongzhihong

"make parameterCheckpoint exported"

上级 87e7924e
...@@ -21,7 +21,7 @@ func main() { ...@@ -21,7 +21,7 @@ func main() {
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job") numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path") checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Int("checkpoint-interval", 10, "save checkpoint per interval seconds") checkpointInterval := flag.Int("checkpoint-interval", 600, "save checkpoint per interval seconds")
logLevel := flag.String("log-level", "info", logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic") "log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse() flag.Parse()
......
...@@ -51,8 +51,8 @@ type ParameterWithConfig struct { ...@@ -51,8 +51,8 @@ type ParameterWithConfig struct {
Config []byte // parameter configuration in Proto Buffer format Config []byte // parameter configuration in Proto Buffer format
} }
// Checkpoint of Parameter and State // ParameterCheckpoint is Parameter and State checkpoint
type parameterCheckPoint struct { type ParameterCheckpoint struct {
ParamConfig ParameterWithConfig ParamConfig ParameterWithConfig
State []byte State []byte
} }
...@@ -65,7 +65,7 @@ type checkpointMeta struct { ...@@ -65,7 +65,7 @@ type checkpointMeta struct {
} }
// Checkpoint is the pserver shard persist in file // Checkpoint is the pserver shard persist in file
type Checkpoint []parameterCheckPoint type Checkpoint []ParameterCheckpoint
// Gradient is the gradient of the parameter. // Gradient is the gradient of the parameter.
type Gradient Parameter type Gradient Parameter
...@@ -186,10 +186,10 @@ func (s *Service) doCheckpoint() error { ...@@ -186,10 +186,10 @@ func (s *Service) doCheckpoint() error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
cp := make([]parameterCheckPoint, 0, len(s.optMap)) cp := make([]ParameterCheckpoint, 0, len(s.optMap))
index := 0 index := 0
for name, opt := range s.optMap { for name, opt := range s.optMap {
var pc parameterCheckPoint var pc ParameterCheckpoint
pc.ParamConfig.Param.Name = name pc.ParamConfig.Param.Name = name
pc.ParamConfig.Param.ElementType = opt.elementType pc.ParamConfig.Param.ElementType = opt.elementType
pc.ParamConfig.Param.Content = opt.GetWeights() pc.ParamConfig.Param.Content = opt.GetWeights()
...@@ -210,8 +210,8 @@ func (s *Service) doCheckpoint() error { ...@@ -210,8 +210,8 @@ func (s *Service) doCheckpoint() error {
h := md5.New() h := md5.New()
cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes())) cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes()))
cpMetajson, err := json.Marshal(cpMeta) cpMetajson, _ := json.Marshal(cpMeta)
s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3) err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3)
if err != nil { if err != nil {
return err return err
} }
...@@ -224,13 +224,13 @@ func (s *Service) doCheckpoint() error { ...@@ -224,13 +224,13 @@ func (s *Service) doCheckpoint() error {
f, err := os.Create(cpMeta.UUID) f, err := os.Create(cpMeta.UUID)
defer f.Close() defer f.Close()
if err != nil { if err != nil {
log.Errorln(err) return err
} }
writer := bufio.NewWriter(f) writer := bufio.NewWriter(f)
_, err = writer.Write(buf.Bytes()) _, err = writer.Write(buf.Bytes())
writer.Flush() writer.Flush()
if err != nil { if err != nil {
log.Errorln(err) return err
} }
return nil return nil
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册