diff --git a/paddle/go/pserver/lib/client/main.go b/paddle/go/pserver/lib/client/main.go index 575f536ba2e1fb82517b30090ada61e2c723dc7f..9040d24611ae2c640f6bd5ab2f8334dc93fda1d3 100644 --- a/paddle/go/pserver/lib/client/main.go +++ b/paddle/go/pserver/lib/client/main.go @@ -31,10 +31,7 @@ import ( "github.com/PaddlePaddle/Paddle/paddle/go/pserver" ) -const ( - ptrSize = unsafe.Sizeof(uintptr(0)) -) - +var nullPtr = unsafe.Pointer(uintptr(0)) var mu sync.Mutex var handleMap = make(map[C.client]*pserver.Client) var curHandle C.client @@ -63,6 +60,10 @@ func remove(client C.client) *pserver.Client { } func cArrayToSlice(p unsafe.Pointer, len int) []byte { + if p == nullPtr { + return nil + } + // create a Go clice backed by a C array, // reference: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices return (*[1 << 30]byte)(p)[:len:len] @@ -131,7 +132,7 @@ func paddle_finish_init_params(client C.client) C.int { func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C.int { var gs []pserver.Gradient for i := 0; i < int(total); i++ { - grad := (*C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*ptrSize))) + grad := (*C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads)))) et := pserver.ElementType(grad.element_type) name := C.GoString(grad.name) content := cArrayToSlice(unsafe.Pointer(grad.content), int(grad.content_len)) @@ -152,7 +153,7 @@ func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C func paddle_get_params(client C.client, names **C.char, dst *C.paddle_parameter, total C.int) C.int { var ns []string for i := 0; i < int(total); i++ { - name := *(**C.char)(unsafe.Pointer((uintptr(unsafe.Pointer(names)) + uintptr(i)*ptrSize))) + name := *(**C.char)(unsafe.Pointer((uintptr(unsafe.Pointer(names)) + uintptr(i)*unsafe.Sizeof(*names)))) ns = append(ns, C.GoString(name)) } c := get(client) @@ -169,18 +170,24 @@ func paddle_get_params(client C.client, names **C.char, dst *C.paddle_parameter, p := ps[i] name := C.CString(p.Name) - param := (*C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*ptrSize))) + param := (*C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst)))) - if unsafe.Pointer(param.name) != unsafe.Pointer(uintptr(0)) { - C.free(unsafe.Pointer(param.name)) + if unsafe.Pointer(param.name) != nullPtr { + if n := C.GoString(param.name); n != p.Name { + log.Println("warning: pre-allocated parameter name not match parameter name, pre-allocated parameter name will be freed.", n, p.Name) + C.free(unsafe.Pointer(param.name)) + param.name = name + } + } else { + param.name = name } - param.name = name memReady := false - if param.content != unsafe.Pointer(uintptr(0)) { - if int(param.content_len) == len(p.Content) { + if param.content != nullPtr { + if int(param.content_len) < len(p.Content) { memReady = true } else { + log.Println("warning: pre-allocated content len is smaller than parameter content len, pre-allocated content will be freed.", param.content_len, len(p.Content)) C.free(param.content) } } diff --git a/paddle/go/pserver/lib/client/test/main.c b/paddle/go/pserver/lib/client/test/main.c index 7bf89eddf93b4eff3faec8f295cfc1484a601768..f27e57508c2e292380dc666fb019a504dadac9b0 100644 --- a/paddle/go/pserver/lib/client/test/main.c +++ b/paddle/go/pserver/lib/client/test/main.c @@ -2,7 +2,56 @@ //#include "gtest/gtest.h" +void panic() { + // TODO(helin): fix: gtest using cmake is not working, using this + // hacky way for now. + *(void*)0; +} + int main() { - client c = paddle_new_pserver_client(NULL); + char addr[] = "localhost:3000"; + client c = paddle_new_pserver_client(addr); + retry: + if (paddle_begin_init_params(c, NULL, 0)) { + paddle_parameter param; + char name_a[] = "param_a"; + char name_b[] = "param_b"; + char content[] = {0x00, 0x11, 0x22}; + param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32; + param.name = name_a; + param.content = content; + param.content_len = 3; + if (paddle_init_param(c, param, NULL, 0) != 0) { + goto retry; + } + param.element_type = PADDLE_ELEMENT_TYPE_INT32; + param.name = name_b; + param.content = content; + param.content_len = 3; + if (paddle_init_param(c, param, NULL, 0) != 0) { + goto retry; + } + if (paddle_finish_init_params(c) != 0) { + goto retry; + } + } else { + panic(); + } + + char content[] = {0x00, 0x11, 0x22}; + paddle_gradient params[] = {{"param_a", PADDLE_ELEMENT_TYPE_INT32, content, 3}, {"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3}}; + if (!paddle_send_grads(c, params, 2)) { + panic(); + } + + char* names[]={"param_a", "param_b"}; + if (!paddle_get_params(c, names, params, 2)) { + panic(); + } + + if (!paddle_save_model(c, "/tmp/")) { + panic(); + } + return 0; }