diff --git a/go/pserver/optimizer.go b/go/pserver/optimizer.go index 4ecae0911c2d1bbbc5258ab48d4140e52d9156ca..af7faad25495bd200e979485e4ca6a8b4170e0fd 100644 --- a/go/pserver/optimizer.go +++ b/go/pserver/optimizer.go @@ -47,7 +47,7 @@ func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer { } func (o *optimizer) GetWeights(p *Parameter) error { - + // FIXME: get weigths from optimizer has bug var buffer unsafe.Pointer buffer_len := C.paddle_optimizer_get_weights(o.opt, &buffer) if buffer_len == 0 || buffer == nullPtr { @@ -59,7 +59,7 @@ func (o *optimizer) GetWeights(p *Parameter) error { func (o *optimizer) UpdateParameter(g Gradient) error { if o.ElementType != g.ElementType { - return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, g.ElementType, g.ElementType) + return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.ElementType, g.ElementType) } // FIXME: do we need a copy? discard g.Content by GC ok diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index a88e2df73adaafe2ac71a178ca7ecba4cb24e9b3..a09b25dec0866f55efaadd75fdbf3d727c0f6d4e 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -2,6 +2,7 @@ package pserver_test import ( "io/ioutil" + "reflect" "sync" "testing" "time" @@ -9,7 +10,7 @@ import ( "github.com/PaddlePaddle/Paddle/go/pserver" ) -func TestFull(t *testing.T) { +func TestNewName(t *testing.T) { s := pserver.NewService() var p pserver.Parameter p.Name = "param_a" @@ -25,69 +26,69 @@ func TestFull(t *testing.T) { t.FailNow() } - // var p1 pserver.Parameter - // p1.Name = "param_b" - // p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} - // p1.ElementType = pserver.Float32 - // fmt.Println("paddle passed") - // err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil) - // if err != nil { - // t.FailNow() - // } - - // err = s.FinishInitParams(0, nil) - // if err != nil { - // t.FailNow() - // } - - // var param pserver.Parameter - // err = s.GetParam("param_b", ¶m) - // if err != nil { - // t.FailNow() - // } - - // if !reflect.DeepEqual(param, p1) { - // t.FailNow() - // } - - // g1, g2 := pserver.Gradient(p1), pserver.Gradient(p) - // err = s.SendGrad(g1, nil) - // if err != nil { - // t.FailNow() - // } - // err = s.SendGrad(g2, nil) - - // if err != nil { - // t.FailNow() - // } - - // var param1 pserver.Parameter - // err = s.GetParam("param_a", ¶m1) - // if err != nil { - // t.FailNow() - // } - - // // don't compare content, since it's already changed by - // // gradient update. - // param1.Content = nil - // p.Content = nil - - // if !reflect.DeepEqual(param1, p) { - // t.FailNow() - // } - // } - - // func TestMultipleInit(t *testing.T) { - // s := pserver.NewService() - // err := s.FinishInitParams(0, nil) - // if err != nil { - // t.FailNow() - // } - - // err = s.FinishInitParams(0, nil) - // if err.Error() != pserver.AlreadyInitialized { - // t.FailNow() - // } + var p1 pserver.Parameter + p1.Name = "param_b" + p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + p1.ElementType = pserver.Float32 + err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: config}, nil) + if err != nil { + t.FailNow() + } + + err = s.FinishInitParams(0, nil) + if err != nil { + t.FailNow() + } + + var param pserver.Parameter + err = s.GetParam("param_b", ¶m) + if err != nil { + t.FailNow() + } + + if !reflect.DeepEqual(param, p1) { + t.FailNow() + } + + g1, g2 := pserver.Gradient(p1), pserver.Gradient(p) + + err = s.SendGrad(g1, nil) + if err != nil { + t.FailNow() + } + err = s.SendGrad(g2, nil) + + if err != nil { + t.FailNow() + } + + var param1 pserver.Parameter + err = s.GetParam("param_a", ¶m1) + if err != nil { + t.FailNow() + } + + // don't compare content, since it's already changed by + // gradient update. + param1.Content = nil + p.Content = nil + + if !reflect.DeepEqual(param1, p) { + t.FailNow() + } +} + +func TestMultipleInit(t *testing.T) { + s := pserver.NewService() + err := s.FinishInitParams(0, nil) + if err != nil { + t.FailNow() + } + + err = s.FinishInitParams(0, nil) + if err.Error() != pserver.AlreadyInitialized { + t.FailNow() + } } func TestUninitialized(t *testing.T) {