提交 774604cd 编写于 作者: D dongzhihong

"add more NewService argument"

上级 8426beb4
...@@ -191,8 +191,8 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { ...@@ -191,8 +191,8 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
// PutKey put into etcd with value by key specified // PutKey put into etcd with value by key specified
func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error { func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error {
ctx, err := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout)) ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
_, err = e.Put(ctx, key, value) _, err := e.etcdClient.Put(ctx, key, string(value))
cancel() cancel()
if err != nil { if err != nil {
return err return err
......
...@@ -19,6 +19,6 @@ func TestOptimizerCreateRelease(t *testing.T) { ...@@ -19,6 +19,6 @@ func TestOptimizerCreateRelease(t *testing.T) {
Param: p, Param: p,
Config: config, Config: config,
} }
o := newOptimizer(param) o := newOptimizer(param, nil)
o.Cleanup() o.Cleanup()
} }
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"bytes" "bytes"
"crypto/md5" "crypto/md5"
"encoding/gob" "encoding/gob"
"encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
...@@ -67,30 +68,19 @@ type checkpointMeta struct { ...@@ -67,30 +68,19 @@ type checkpointMeta struct {
type Checkpoint []parameterCheckPoint type Checkpoint []parameterCheckPoint
// Gradient is the gradient of the parameter. // Gradient is the gradient of the parameter.
type Gradient Parameter
// Service is the RPC service for pserver. // Service is the RPC service for pserver.
type Service struct { type Service struct {
initialized chan struct{} initialized chan struct{}
idx int idx int
checkpointInterval int checkpointInterval time.Duration
checkpointPath string checkpointPath string
client *EtcdClient client *EtcdClient
mu sync.Mutex mu sync.Mutex
optMap map[string]*optimizer optMap map[string]*optimizer
} }
// //serialize ParameterWithConfig to byte stream
// func GetBytes(content ...interface{}) ([]byte, error) {
// var buf bytes.Buffer
// encoder := gob.NewEncoder(&buf)
// err := encoder.Encode(content)
// if err != nil {
// return nil, err
// }
// return buf.Bytes(), nil
// }
// NewService creates a new service, will bypass etcd registration if no // NewService creates a new service, will bypass etcd registration if no
// endpoints specified. // endpoints specified.
func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) { func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
...@@ -129,7 +119,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er ...@@ -129,7 +119,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
// TODO(helin): check if paramWithConfigs.Param.Content is // TODO(helin): check if paramWithConfigs.Param.Content is
// properly memory aligned, if not, make copy to a memory // properly memory aligned, if not, make copy to a memory
// aligned region. // aligned region.
s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs) s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil)
return nil return nil
} }
...@@ -218,7 +208,7 @@ func (s *Service) doCheckpoint() error { ...@@ -218,7 +208,7 @@ func (s *Service) doCheckpoint() error {
cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx) cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
cpMeta.Timestamp = time.Now().String() cpMeta.Timestamp = time.Now().String()
h := md5.New() h := md5.New()
cpMeta.Md5sum = h.Sum(buf.Bytes()) cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes()))
cpMetajson, err := json.Marshal(cpMeta) cpMetajson, err := json.Marshal(cpMeta)
s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3) s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3)
......
...@@ -87,7 +87,8 @@ func TestServiceFull(t *testing.T) { ...@@ -87,7 +87,8 @@ func TestServiceFull(t *testing.T) {
} }
func TestMultipleInit(t *testing.T) { func TestMultipleInit(t *testing.T) {
s, err := pserver.NewService(0) var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
...@@ -103,7 +104,8 @@ func TestMultipleInit(t *testing.T) { ...@@ -103,7 +104,8 @@ func TestMultipleInit(t *testing.T) {
} }
func TestUninitialized(t *testing.T) { func TestUninitialized(t *testing.T) {
s, err := pserver.NewService(0) var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
err = s.SendGrad(pserver.Gradient{}, nil) err = s.SendGrad(pserver.Gradient{}, nil)
if err.Error() != pserver.Uninitialized { if err.Error() != pserver.Uninitialized {
t.FailNow() t.FailNow()
...@@ -111,7 +113,8 @@ func TestUninitialized(t *testing.T) { ...@@ -111,7 +113,8 @@ func TestUninitialized(t *testing.T) {
} }
func TestBlockUntilInitialized(t *testing.T) { func TestBlockUntilInitialized(t *testing.T) {
s, err := pserver.NewService(0) var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
...@@ -129,16 +132,6 @@ func TestBlockUntilInitialized(t *testing.T) { ...@@ -129,16 +132,6 @@ func TestBlockUntilInitialized(t *testing.T) {
ch <- struct{}{} ch <- struct{}{}
}() }()
wg.Add(1)
go func() {
err := s.Save("", nil)
if err != nil {
errCh <- err
}
wg.Done()
ch <- struct{}{}
}()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
select { select {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册