diff --git a/paddle/go/pserver/lib/client/main.go b/paddle/go/pserver/lib/client/main.go index ec61398ef454891cfa681da418c361cd5f5577fa..92aa54bd1090e83572cf56084fd84815cda5313b 100644 --- a/paddle/go/pserver/lib/client/main.go +++ b/paddle/go/pserver/lib/client/main.go @@ -150,7 +150,7 @@ func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C } //export paddle_get_params -func paddle_get_params(client C.client, names **C.char, dst *C.paddle_parameter, reuse C.int, total C.int) C.int { +func paddle_get_params(client C.client, names **C.char, dst **C.paddle_parameter, total C.int) C.int { var ns []string for i := 0; i < int(total); i++ { name := *(**C.char)(unsafe.Pointer((uintptr(unsafe.Pointer(names)) + uintptr(i)*unsafe.Sizeof(*names)))) @@ -169,30 +169,36 @@ func paddle_get_params(client C.client, names **C.char, dst *C.paddle_parameter, } p := ps[i] - name := C.CString(p.Name) - param := (*C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst)))) - - if reuse != 0 && unsafe.Pointer(param.name) != nullPtr { - if n := C.GoString(param.name); n != p.Name { - log.Println("warning: pre-allocated parameter name not match parameter name, pre-allocated parameter name will be freed.", n, p.Name) - C.free(unsafe.Pointer(param.name)) - param.name = name - } + param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst)))) + nameReady := false + contentAllocated := false + + if unsafe.Pointer(param) == nullPtr { + param = (*C.paddle_parameter)(C.calloc(1, C.size_t(unsafe.Sizeof(*param)))) } else { - param.name = name - } + if unsafe.Pointer(param.name) != nullPtr { + if n := C.GoString(param.name); n != p.Name { + log.Println("warning: pre-allocated parameter name not match parameter name, pre-allocated parameter name will be freed.", n, p.Name) + C.free(unsafe.Pointer(param.name)) + } else { + nameReady = true + } + } - memReady := false - if reuse != 0 && param.content != nullPtr { - if int(param.content_len) < len(p.Content) { - memReady = true - } else { - log.Println("warning: pre-allocated content len is smaller than parameter content len, pre-allocated content will be freed.", param.content_len, len(p.Content)) - C.free(param.content) + if param.content != nullPtr { + if int(param.content_len) == len(p.Content) { + contentAllocated = true + } else { + log.Println("warning: pre-allocated content len does not match parameter content len, pre-allocated content will be freed.", param.content_len, len(p.Content)) + C.free(param.content) + } } } - if !memReady { + if !nameReady { + param.name = C.CString(p.Name) + } + if !contentAllocated { param.content = C.malloc(C.size_t(len(p.Content))) } C.memcpy(param.content, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content))) diff --git a/paddle/go/pserver/lib/client/test/main.c b/paddle/go/pserver/lib/client/test/main.c index d9d9e997017c0c036c0994359940fdd4d1ad9d34..6fec3a7a83a616b37f8c12a726d58951c369c6ca 100644 --- a/paddle/go/pserver/lib/client/test/main.c +++ b/paddle/go/pserver/lib/client/test/main.c @@ -45,9 +45,9 @@ int main() { panic(); } - paddle_parameter params[2]; + paddle_parameter* params[2] = {NULL, NULL}; char* names[]={"param_a", "param_b"}; - if (!paddle_get_params(c, names, params, 0, 2)) { + if (!paddle_get_params(c, names, params, 2)) { panic(); }