diff --git a/go/pserver/etcd_client.go b/go/pserver/etcd_client.go index 20041d04d089a08a513b607fa887d45f9f837ebb..1f77787150d16052e3588e9c1795c8d5dafa08e6 100644 --- a/go/pserver/etcd_client.go +++ b/go/pserver/etcd_client.go @@ -191,8 +191,8 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { // PutKey put into etcd with value by key specified func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error { - ctx, err := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout)) - _, err = e.Put(ctx, key, value) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout)) + _, err := e.etcdClient.Put(ctx, key, string(value)) cancel() if err != nil { return err diff --git a/go/pserver/optimizer_test.go b/go/pserver/optimizer_test.go index 0b2f4cfa41a630645c128ac13826de9d8b1d521b..d19e9de92e0b33b1d9619adb615a24884097a38f 100644 --- a/go/pserver/optimizer_test.go +++ b/go/pserver/optimizer_test.go @@ -19,6 +19,6 @@ func TestOptimizerCreateRelease(t *testing.T) { Param: p, Config: config, } - o := newOptimizer(param) + o := newOptimizer(param, nil) o.Cleanup() } diff --git a/go/pserver/service.go b/go/pserver/service.go index f27feb247d84a0a171bb5115d7f949a6edbb8b41..cb3741af7ab95a1c2084a3fe34a9607309dbd017 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -5,6 +5,7 @@ import ( "bytes" "crypto/md5" "encoding/gob" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -67,30 +68,19 @@ type checkpointMeta struct { type Checkpoint []parameterCheckPoint // Gradient is the gradient of the parameter. +type Gradient Parameter // Service is the RPC service for pserver. type Service struct { initialized chan struct{} idx int - checkpointInterval int + checkpointInterval time.Duration checkpointPath string client *EtcdClient mu sync.Mutex optMap map[string]*optimizer } -// //serialize ParameterWithConfig to byte stream -// func GetBytes(content ...interface{}) ([]byte, error) { - -// var buf bytes.Buffer -// encoder := gob.NewEncoder(&buf) -// err := encoder.Encode(content) -// if err != nil { -// return nil, err -// } -// return buf.Bytes(), nil -// } - // NewService creates a new service, will bypass etcd registration if no // endpoints specified. func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) { @@ -129,7 +119,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er // TODO(helin): check if paramWithConfigs.Param.Content is // properly memory aligned, if not, make copy to a memory // aligned region. - s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs) + s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil) return nil } @@ -218,7 +208,7 @@ func (s *Service) doCheckpoint() error { cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx) cpMeta.Timestamp = time.Now().String() h := md5.New() - cpMeta.Md5sum = h.Sum(buf.Bytes()) + cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes())) cpMetajson, err := json.Marshal(cpMeta) s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3) diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index 75d4732ea7862a91126bdd45c2c6d96af049ed54..f365a4539a2d0489fb72b3f2ad1a478bb8d1f727 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -87,7 +87,8 @@ func TestServiceFull(t *testing.T) { } func TestMultipleInit(t *testing.T) { - s, err := pserver.NewService(0) + var cp pserver.Checkpoint + s, err := pserver.NewService(0, 1, "", nil, cp) if err != nil { t.Error(err) } @@ -103,7 +104,8 @@ func TestMultipleInit(t *testing.T) { } func TestUninitialized(t *testing.T) { - s, err := pserver.NewService(0) + var cp pserver.Checkpoint + s, err := pserver.NewService(0, 1, "", nil, cp) err = s.SendGrad(pserver.Gradient{}, nil) if err.Error() != pserver.Uninitialized { t.FailNow() @@ -111,7 +113,8 @@ func TestUninitialized(t *testing.T) { } func TestBlockUntilInitialized(t *testing.T) { - s, err := pserver.NewService(0) + var cp pserver.Checkpoint + s, err := pserver.NewService(0, 1, "", nil, cp) if err != nil { t.Error(err) } @@ -129,16 +132,6 @@ func TestBlockUntilInitialized(t *testing.T) { ch <- struct{}{} }() - wg.Add(1) - go func() { - err := s.Save("", nil) - if err != nil { - errCh <- err - } - wg.Done() - ch <- struct{}{} - }() - time.Sleep(50 * time.Millisecond) select {