diff --git a/paddle/go/pserver/lib/client/main.go b/paddle/go/pserver/lib/client/main.go index 92aa54bd1090e83572cf56084fd84815cda5313b..8aa757edcdc1c012ac6d5f659930211a3cfccf23 100644 --- a/paddle/go/pserver/lib/client/main.go +++ b/paddle/go/pserver/lib/client/main.go @@ -15,10 +15,24 @@ typedef enum { typedef struct { char* name; paddle_element_type element_type; - void* content; + char* content; int content_len; } paddle_parameter, paddle_gradient; +static inline void paddle_release_param(paddle_parameter* param) { + if (param != NULL) { + if (param->name != NULL) { + free(param->name); + } + + if (param->content != NULL) { + free(param->content); + } + + free(param); + } +} + typedef int client; */ import "C" @@ -185,12 +199,12 @@ func paddle_get_params(client C.client, names **C.char, dst **C.paddle_parameter } } - if param.content != nullPtr { + if unsafe.Pointer(param.content) != nullPtr { if int(param.content_len) == len(p.Content) { contentAllocated = true } else { log.Println("warning: pre-allocated content len does not match parameter content len, pre-allocated content will be freed.", param.content_len, len(p.Content)) - C.free(param.content) + C.free(unsafe.Pointer(param.content)) } } } @@ -199,9 +213,9 @@ func paddle_get_params(client C.client, names **C.char, dst **C.paddle_parameter param.name = C.CString(p.Name) } if !contentAllocated { - param.content = C.malloc(C.size_t(len(p.Content))) + param.content = (*C.char)(C.malloc(C.size_t(len(p.Content)))) } - C.memcpy(param.content, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content))) + C.memcpy(unsafe.Pointer(param.content), unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content))) param.content_len = C.int(len(p.Content)) param.element_type = C.paddle_element_type(p.ElementType) } diff --git a/paddle/go/pserver/lib/client/test/main.c b/paddle/go/pserver/lib/client/test/main.c index 14d4522eac0b5908724e75c370f6ea8d78e6f60e..619661d4ac3508faf5f2216ecb50be0ecbd3e154 100644 --- a/paddle/go/pserver/lib/client/test/main.c +++ b/paddle/go/pserver/lib/client/test/main.c @@ -8,20 +8,6 @@ void panic() { *(void*)0; } -void releaseParam(paddle_parameter* param) { - if (param != NULL) { - if (param->name != NULL) { - free(param->name); - } - - if (param->content != NULL) { - free(param->content); - } - - free(param); - } -} - int main() { char addr[] = "localhost:3000"; client c = paddle_new_pserver_client(addr); @@ -65,8 +51,8 @@ int main() { panic(); } - releaseParam(params[0]); - releaseParam(params[1]); + paddle_release_param(params[0]); + paddle_release_param(params[1]); if (!paddle_save_model(c, "/tmp/")) { panic();