diff --git a/go/pserver/cclient/cclient.go b/go/pserver/cclient/cclient.go index 6aaaff7409dfcf500a24496b0d11c3eae3eb9348..92a41b7f5434842c6318704dd85adf9e51c19944 100644 --- a/go/pserver/cclient/cclient.go +++ b/go/pserver/cclient/cclient.go @@ -123,8 +123,9 @@ func paddle_begin_init_params(client C.paddle_pserver_client) C.int { func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, param_config unsafe.Pointer, config_len C.int) C.int { et := pserver.ElementType(param.element_type) name := C.GoString(param.name) + content := cArrayToSlice(unsafe.Pointer(param.content), int(param.content_len)) pc := pserver.ParameterWithConfig{ - Param: pserver.Parameter{Name: name, ElementType: et, Content: param.content, Length: para.content_len}, + Param: pserver.Parameter{Name: name, ElementType: et, Content: content}, Config: cArrayToSlice(param_config, int(config_len)), } c := get(client) @@ -166,7 +167,8 @@ func paddle_send_grads(client C.paddle_pserver_client, grads **C.paddle_gradient grad := *(**C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads)))) et := pserver.ElementType(grad.element_type) name := C.GoString(grad.name) - gs = append(gs, pserver.Gradient{Name: name, ElementType: et, Content: grad.content, Length: grad.content_len}) + content := cArrayToSlice(unsafe.Pointer(grad.content), int(grad.content_len)) + gs = append(gs, pserver.Gradient{Name: name, ElementType: et, Content: content}) } c := get(client) @@ -223,14 +225,14 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, } if unsafe.Pointer(param.content) != nullPtr { - if int(param.content_len) != p.Length { + if int(param.content_len) != len(p.Content) { log.Errorf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content)) return C.PSERVER_ERROR } } - C.memcpy(unsafe.Pointer(param.content), unsafe.Pointer(p.Content), C.size_t(p.Length)) - param.content_len = C.int(p.Length) + C.memcpy(unsafe.Pointer(param.content), unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content))) + param.content_len = C.int(len(p.Content)) param.element_type = C.paddle_element_type(p.ElementType) } diff --git a/go/pserver/cclient/test/test_cclient.c b/go/pserver/cclient/test/test_cclient.c index 7d26127b600e0110e8d2ae0e6c514a006efdcd5c..5bd4913ba3cabfbd0988a5015029a96c15ec2793 100644 --- a/go/pserver/cclient/test/test_cclient.c +++ b/go/pserver/cclient/test/test_cclient.c @@ -50,10 +50,10 @@ void getParams(paddle_pserver_client c) { int main() { char addr[] = "localhost:3000"; paddle_pserver_client c = paddle_new_pserver_client(addr, 1); - char config_proto[1024]; + char *config_proto; size_t config_proto_len = 0; ssize_t nread; - FILE *fp = fopen("optimizer.pb.txt", "r"); + FILE *fp = fopen("testdata/optimizer.pb.txt", "r"); if(!fp) { fail(); } while((nread = getline(&config_proto, &config_proto_len, fp)) != -1) { printf("%s", config_proto); @@ -70,7 +70,7 @@ retry: param.name = name_a; param.content = content_a; param.content_len = 2000; - int error = paddle_init_param(c, param, config_proto, config_proto_len); + int error = paddle_init_param(c, param, (void *)config_proto, config_proto_len); if (error != 0) { goto retry; } @@ -79,7 +79,7 @@ retry: param.name = name_b; param.content = content_b; param.content_len = 3000; - error = paddle_init_param(c, param, NULL, 0); + error = paddle_init_param(c, param, (void *)config_proto, config_proto_len); if (error != 0) { goto retry; } diff --git a/go/pserver/client_test.go b/go/pserver/client_test.go index c5d38e41129dc6297f639becf14809e716019c83..d0371a26a13fac9daecacd0b6a271caa6d830651 100644 --- a/go/pserver/client_test.go +++ b/go/pserver/client_test.go @@ -75,9 +75,7 @@ func TestClientFull(t *testing.T) { var p pserver.Parameter p.Name = "p_" + strconv.Itoa(i) p.ElementType = pserver.Float32 - ElementValue := make([]byte, (i+1)*100) - p.Content = &ElementValue[0] - p.Length = len(ElementValue) + p.Content = make([]byte, (i+1)*100) err := c.InitParam(pserver.ParameterWithConfig{Param: p}) if err != nil { t.Fatal(err) @@ -94,9 +92,7 @@ func TestClientFull(t *testing.T) { var g pserver.Gradient g.Name = "p_" + strconv.Itoa(i) g.ElementType = pserver.Float32 - ElementValue := make([]byte, (i+1)*100) - g.Content = &ElementValue[0] - g.Length = len(ElementValue) + g.Content = make([]byte, (i+1)*100) grads = append(grads, g) } diff --git a/go/pserver/optimizer.go b/go/pserver/optimizer.go index 12bf055b4deb2933701ed4c8b4d058e1072ace99..4ecae0911c2d1bbbc5258ab48d4140e52d9156ca 100644 --- a/go/pserver/optimizer.go +++ b/go/pserver/optimizer.go @@ -4,7 +4,7 @@ package pserver // TODO(zhihong): move compile flags to cmake go_library #cgo pkg-config: protobuf #cgo CFLAGS: -I ../../ -#cgo LDFLAGS: /Users/dzh/.go/src/github.com/PaddlePaddle/Paddle/build/go/pserver/cclient/libpaddle_go_optimizer.a +#cgo LDFLAGS: /Users/dzh/.go/src/github.com/PaddlePaddle/Paddle/build/go/pserver/cclient/libpaddle_go_optimizer.a -lstdc++ #include "paddle/optimizer/optimizer.h" */ import "C" @@ -38,17 +38,20 @@ func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer { o := &optimizer{} p := paramWithConfigs.Param c := paramWithConfigs.Config - buffer := &p.Content[0] - o.opt = C.paddle_create_optimizer(C.uchar(c), C.int(len(c)), unsafe.Pointer(buffer), C.int(len(p.Content)), nullPtr, 0) + var cbuffer unsafe.Pointer + cbuffer = unsafe.Pointer(&p.Content[0]) + o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)), + C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)), + (*C.char)(nullPtr), 0) return o } func (o *optimizer) GetWeights(p *Parameter) error { var buffer unsafe.Pointer - buffer_len := C.paddle_optimizer_get_weights(unsafe.Pointer(o), &buffer) + buffer_len := C.paddle_optimizer_get_weights(o.opt, &buffer) if buffer_len == 0 || buffer == nullPtr { - return fmt.Errorf("parameter optimizer error : %s get failed", p.name) + return fmt.Errorf("parameter optimizer error : %s get failed", p.Name) } p.Content = cArrayToSlice(buffer, int(buffer_len)) return nil @@ -60,7 +63,9 @@ func (o *optimizer) UpdateParameter(g Gradient) error { } // FIXME: do we need a copy? discard g.Content by GC ok - r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), unsafe.Pointer(g.Content), C.int(len(g.Content))) + var cbuffer unsafe.Pointer + cbuffer = unsafe.Pointer(&g.Content[0]) + r := C.paddle_update_parameter(o.opt, C.paddle_element_type(g.ElementType), cbuffer, C.int(len(g.Content))) if r != 0 { return fmt.Errorf("optimizer update returned error code: %d", r) } diff --git a/go/pserver/optimizer_test.go b/go/pserver/optimizer_test.go index eac744b5cdb1be4f5fe593add8d836a78c9c6224..368047d6f89e080016909efbc5bd090c42530bfd 100644 --- a/go/pserver/optimizer_test.go +++ b/go/pserver/optimizer_test.go @@ -8,11 +8,13 @@ import ( func TestOptimizerCreateRelease(t *testing.T) { p := Parameter{ Name: "a", - ElementType: Float32, + ElementType: Int32, } - p.Content = []byte{0.1, 0.3} + p.Content = []byte{1, 3} config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb.txt") - + if err != nil { + t.Fatalf("read optimizer proto failed") + } param := ParameterWithConfig{ Param: p, Config: config, diff --git a/go/pserver/service.go b/go/pserver/service.go index d0d57136b5ef1d5378cb8efef19dcf6a5755d88a..cdd433260af0d3ee2eee4df2cfc270f65e395e69 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -128,11 +128,11 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { // nature. This race condition is allowed deliberately // to save the program from making a copy of the // paramter content. - p.Name = name - p.ElementType = opt.ElementType + parameter.Name = name + parameter.ElementType = opt.ElementType - ok := opt.GetWeights(¶meter) - return ok + err := opt.GetWeights(parameter) + return err } // Save tells the parameter server to save parameters.