提交 27fdccc3 编写于 作者: H Helin Wang

fix according to comments

上级 9920a06c
...@@ -9,8 +9,7 @@ import ( ...@@ -9,8 +9,7 @@ import (
// ElementType is the type of elements of a Parameter. // ElementType is the type of elements of a Parameter.
type ElementType int type ElementType int
var ErrUnintialized = errors.New("pserver not initialized") var ErrAlreadyInitialized = errors.New("pserver already initialized")
var ErrAlreadyIntialized = errors.New("pserver already initialized")
// Supported element types // Supported element types
const ( const (
...@@ -56,7 +55,7 @@ func NewService() *Service { ...@@ -56,7 +55,7 @@ func NewService() *Service {
func (s *Service) BeginInitParams(config []byte, dummy *int) error { func (s *Service) BeginInitParams(config []byte, dummy *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
return ErrAlreadyIntialized return ErrAlreadyInitialized
default: default:
} }
...@@ -75,7 +74,7 @@ func (s *Service) BeginInitParams(config []byte, dummy *int) error { ...@@ -75,7 +74,7 @@ func (s *Service) BeginInitParams(config []byte, dummy *int) error {
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error { func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
return ErrAlreadyIntialized return ErrAlreadyInitialized
default: default:
} }
...@@ -94,7 +93,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er ...@@ -94,7 +93,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error { func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
return ErrAlreadyIntialized return ErrAlreadyInitialized
default: default:
} }
...@@ -103,11 +102,7 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error { ...@@ -103,11 +102,7 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
} }
func (s *Service) SendGrads(grads []Gradient, dummy *int) error { func (s *Service) SendGrads(grads []Gradient, dummy *int) error {
select { <-s.initialized
case <-s.initialized:
default:
return ErrUnintialized
}
count := len(grads) count := len(grads)
if count == 0 { if count == 0 {
......
...@@ -98,21 +98,12 @@ func TestMultipleInit(t *testing.T) { ...@@ -98,21 +98,12 @@ func TestMultipleInit(t *testing.T) {
} }
err = s.FinishInitParams(0, &dummy) err = s.FinishInitParams(0, &dummy)
if err != pserver.ErrAlreadyIntialized { if err != pserver.ErrAlreadyInitialized {
t.FailNow() t.FailNow()
} }
err = s.BeginInitParams(nil, &dummy) err = s.BeginInitParams(nil, &dummy)
if err != pserver.ErrAlreadyIntialized { if err != pserver.ErrAlreadyInitialized {
t.FailNow()
}
}
func TestUninitialized(t *testing.T) {
s := pserver.NewService()
var dummy int
err := s.SendGrads(nil, &dummy)
if err != pserver.ErrUnintialized {
t.FailNow() t.FailNow()
} }
} }
...@@ -140,6 +131,16 @@ func TestBlockUntilInitialized(t *testing.T) { ...@@ -140,6 +131,16 @@ func TestBlockUntilInitialized(t *testing.T) {
wg.Done() wg.Done()
}() }()
wg.Add(1)
go func() {
var dummy int
err := s.SendGrads(nil, &dummy)
if err != nil {
t.FailNow()
}
wg.Done()
}()
var dummy int var dummy int
err := s.BeginInitParams(nil, &dummy) err := s.BeginInitParams(nil, &dummy)
if err != nil { if err != nil {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册