提交 f1330e21 编写于 作者: D dongzhihong

"saving checkpoint"

上级 5ef1425a
package pserver package pserver
import ( import (
"bufio"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"os"
"strconv"
"sync" "sync"
"time"
log "github.com/sirupsen/logrus"
) )
// ElementType is the type of elements of a Parameter. // ElementType is the type of elements of a Parameter.
...@@ -14,6 +24,10 @@ const ( ...@@ -14,6 +24,10 @@ const (
Uninitialized = "pserver not fully initialized" Uninitialized = "pserver not fully initialized"
) )
const (
checkpoint_path = "/checkpoints/"
)
// Supported element types // Supported element types
const ( const (
Int32 ElementType = iota Int32 ElementType = iota
...@@ -53,6 +67,24 @@ type Service struct { ...@@ -53,6 +67,24 @@ type Service struct {
optMap map[string]*optimizer optMap map[string]*optimizer
} }
type Checkpoint struct {
uuid string
md5sum string
timestamp string
}
//serialize ParameterWithConfig to byte stream
func GetBytes(content ...interface{}) ([]byte, error) {
var buf bytes.Buffer
encoder := gob.NewEncoder(&buf)
err := encoder.Encode(content)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// NewService creates a new service, will bypass etcd registration if no // NewService creates a new service, will bypass etcd registration if no
// endpoints specified. // endpoints specified.
func NewService(idx int) (*Service, error) { func NewService(idx int) (*Service, error) {
...@@ -143,13 +175,50 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { ...@@ -143,13 +175,50 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
// Save tells the parameter server to save parameters. // Save tells the parameter server to save parameters.
func (s *Service) Save(path string, dummy *int) error { func (s *Service) Save(path string, dummy *int) error {
//FIXME: checkpoint is only used by pserver
// and has a constant path of */checkpoints/{pserver_idx}*
<-s.initialized <-s.initialized
for opt, ok := range s.optMap { s.mu.Lock()
if ok != nil { defer s.mu.Unlock()
return fmt.Errorf("parameter optimizerMap error: ", ok) var paramWithConfig ParameterWithConfig
for name, opt := range s.optMap {
paramWithConfig.Param.Name = name
paramWithConfig.Param.ElementType = opt.elementType
paramWithConfig.Param.Content = opt.GetWeights()
paramWithConfig.State = opt.GetStates()
content, err := GetBytes(paramWithConfig)
if err != nil {
log.Errorln(err)
}
ck := Checkpoint{}
h := md5.New()
ck.md5sum = hex.EncodeToString(h.Sum(content))
ck.timestamp = time.Now().String()
ck.uuid = checkpoint_path + strconv.Itoa(s.idx)
ckbytes, err := GetBytes(ck)
if err != nil {
log.Errorln(err)
}
// TODO: according design doc, need to save uuid to etcd in json format
// {\"uuid\": [UUID], \"md5\", \"MD5 sum\", \"timestamp\": xxxx}
log.Infof("parameter checkpoint %s", ckbytes)
if _, err = os.Stat(ck.uuid); os.IsNotExist(err) {
log.Info("checkpoint not exists.")
} else {
err = os.Remove(ck.uuid)
log.Infof("remove %s", ck.uuid)
}
f, err := os.Create(ck.uuid)
defer f.Close()
if err != nil {
log.Errorln(err)
}
writer := bufio.NewWriter(f)
_, err = writer.Write(content)
if err != nil {
log.Errorln(err)
} }
state := opt.GetStates()
weights := opt.GetWeights()
} }
return nil return nil
} }
...@@ -79,6 +79,8 @@ func TestServiceFull(t *testing.T) { ...@@ -79,6 +79,8 @@ func TestServiceFull(t *testing.T) {
if !reflect.DeepEqual(param1, p) { if !reflect.DeepEqual(param1, p) {
t.FailNow() t.FailNow()
} }
var dummy int
s.Save("", &dummy)
} }
func TestMultipleInit(t *testing.T) { func TestMultipleInit(t *testing.T) {
...@@ -166,3 +168,7 @@ func TestBlockUntilInitialized(t *testing.T) { ...@@ -166,3 +168,7 @@ func TestBlockUntilInitialized(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestCheckpointSpeed(t *testing.T) {
//TODO: test speed
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册