From e21d56ee3fea86fc9728b9419ee031a5a89faa8c Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Tue, 16 May 2017 16:18:49 -0400 Subject: [PATCH] add reuse variable --- paddle/go/pserver/lib/client/main.go | 6 +++--- paddle/go/pserver/lib/client/test/main.c | 8 +++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/paddle/go/pserver/lib/client/main.go b/paddle/go/pserver/lib/client/main.go index 9040d24611a..ec61398ef45 100644 --- a/paddle/go/pserver/lib/client/main.go +++ b/paddle/go/pserver/lib/client/main.go @@ -150,7 +150,7 @@ func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C } //export paddle_get_params -func paddle_get_params(client C.client, names **C.char, dst *C.paddle_parameter, total C.int) C.int { +func paddle_get_params(client C.client, names **C.char, dst *C.paddle_parameter, reuse C.int, total C.int) C.int { var ns []string for i := 0; i < int(total); i++ { name := *(**C.char)(unsafe.Pointer((uintptr(unsafe.Pointer(names)) + uintptr(i)*unsafe.Sizeof(*names)))) @@ -172,7 +172,7 @@ func paddle_get_params(client C.client, names **C.char, dst *C.paddle_parameter, name := C.CString(p.Name) param := (*C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst)))) - if unsafe.Pointer(param.name) != nullPtr { + if reuse != 0 && unsafe.Pointer(param.name) != nullPtr { if n := C.GoString(param.name); n != p.Name { log.Println("warning: pre-allocated parameter name not match parameter name, pre-allocated parameter name will be freed.", n, p.Name) C.free(unsafe.Pointer(param.name)) @@ -183,7 +183,7 @@ func paddle_get_params(client C.client, names **C.char, dst *C.paddle_parameter, } memReady := false - if param.content != nullPtr { + if reuse != 0 && param.content != nullPtr { if int(param.content_len) < len(p.Content) { memReady = true } else { diff --git a/paddle/go/pserver/lib/client/test/main.c b/paddle/go/pserver/lib/client/test/main.c index f27e57508c2..d9d9e997017 100644 --- a/paddle/go/pserver/lib/client/test/main.c +++ b/paddle/go/pserver/lib/client/test/main.c @@ -39,13 +39,15 @@ int main() { } char content[] = {0x00, 0x11, 0x22}; - paddle_gradient params[] = {{"param_a", PADDLE_ELEMENT_TYPE_INT32, content, 3}, {"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3}}; - if (!paddle_send_grads(c, params, 2)) { + paddle_gradient grads[2] = {{"param_a", PADDLE_ELEMENT_TYPE_INT32, content, 3}, {"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3}}; + + if (!paddle_send_grads(c, grads, 2)) { panic(); } + paddle_parameter params[2]; char* names[]={"param_a", "param_b"}; - if (!paddle_get_params(c, names, params, 2)) { + if (!paddle_get_params(c, names, params, 0, 2)) { panic(); } -- GitLab