diff --git a/go/pserver/optimizer.go b/go/pserver/optimizer.go index 070896f7c7611f2429e77df030d1511a34f95b0f..b4a040f46bff5c25b193d41e5d36b59762891574 100644 --- a/go/pserver/optimizer.go +++ b/go/pserver/optimizer.go @@ -49,7 +49,7 @@ func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer { cbuffer = C.malloc(C.size_t(len(p.Content))) C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content))) o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)), - C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)*C.sizeof_float), + C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float), (*C.char)(nullPtr), 0) return o } @@ -57,7 +57,7 @@ func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer { func (o *optimizer) GetWeights() []byte { var buffer unsafe.Pointer buffer_len := C.paddle_optimizer_get_weights(o.opt, &buffer) - return cArrayToSlice(buffer, int(buffer_len)) + return cArrayToSlice(buffer, int(buffer_len)*C.sizeof_float) } func (o *optimizer) UpdateParameter(g Gradient) error { @@ -65,7 +65,7 @@ func (o *optimizer) UpdateParameter(g Gradient) error { return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.elementType, g.ElementType) } - r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content))) + r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content))/C.sizeof_float) if r != 0 { return fmt.Errorf("optimizer update returned error code: %d", r) } diff --git a/go/pserver/optimizer_test.go b/go/pserver/optimizer_test.go index 49d9df5898dc69a8e50c4a712e9b04467b45a7b8..368047d6f89e080016909efbc5bd090c42530bfd 100644 --- a/go/pserver/optimizer_test.go +++ b/go/pserver/optimizer_test.go @@ -2,7 +2,6 @@ package pserver import ( "io/ioutil" - "reflect" "testing" ) @@ -23,26 +22,3 @@ func TestOptimizerCreateRelease(t *testing.T) { o := newOptimizer(param) o.Cleanup() } - -func TestOptimizerFull(t *testing.T) { - p := Parameter{ - Name: "a", - ElementType: Float32, - } - p.Content = []byte{1, 3} - config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb.txt") - if err != nil { - t.Fatalf("read optimizer proto failed") - } - param := ParameterWithConfig{ - Param: p, - Config: config, - } - o := newOptimizer(param) - g := Gradient(p) - if !reflect.DeepEqual(p.Content, o.GetWeights()) { - t.FailNow() - } - o.UpdateParameter(g) - o.Cleanup() -} diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index c62f92e09bb3d65aa9552e3624e5d8e1a1945e56..f86619447c28b5be8071d28a127b48768939261b 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -10,8 +10,7 @@ import ( "github.com/PaddlePaddle/Paddle/go/pserver" ) - -func TestFull(t *testing.T) { +func TestServiceFull(t *testing.T) { s, err := pserver.NewService(0) if err != nil { t.Error(err)