提交 a5091ef2 编写于 作者: H Helin Wang

fix bug, add functional test

上级 6011b2e7
...@@ -31,10 +31,7 @@ import ( ...@@ -31,10 +31,7 @@ import (
"github.com/PaddlePaddle/Paddle/paddle/go/pserver" "github.com/PaddlePaddle/Paddle/paddle/go/pserver"
) )
const ( var nullPtr = unsafe.Pointer(uintptr(0))
ptrSize = unsafe.Sizeof(uintptr(0))
)
var mu sync.Mutex var mu sync.Mutex
var handleMap = make(map[C.client]*pserver.Client) var handleMap = make(map[C.client]*pserver.Client)
var curHandle C.client var curHandle C.client
...@@ -63,6 +60,10 @@ func remove(client C.client) *pserver.Client { ...@@ -63,6 +60,10 @@ func remove(client C.client) *pserver.Client {
} }
func cArrayToSlice(p unsafe.Pointer, len int) []byte { func cArrayToSlice(p unsafe.Pointer, len int) []byte {
if p == nullPtr {
return nil
}
// create a Go clice backed by a C array, // create a Go clice backed by a C array,
// reference: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices // reference: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
return (*[1 << 30]byte)(p)[:len:len] return (*[1 << 30]byte)(p)[:len:len]
...@@ -131,7 +132,7 @@ func paddle_finish_init_params(client C.client) C.int { ...@@ -131,7 +132,7 @@ func paddle_finish_init_params(client C.client) C.int {
func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C.int { func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C.int {
var gs []pserver.Gradient var gs []pserver.Gradient
for i := 0; i < int(total); i++ { for i := 0; i < int(total); i++ {
grad := (*C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*ptrSize))) grad := (*C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads))))
et := pserver.ElementType(grad.element_type) et := pserver.ElementType(grad.element_type)
name := C.GoString(grad.name) name := C.GoString(grad.name)
content := cArrayToSlice(unsafe.Pointer(grad.content), int(grad.content_len)) content := cArrayToSlice(unsafe.Pointer(grad.content), int(grad.content_len))
...@@ -152,7 +153,7 @@ func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C ...@@ -152,7 +153,7 @@ func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C
func paddle_get_params(client C.client, names **C.char, dst *C.paddle_parameter, total C.int) C.int { func paddle_get_params(client C.client, names **C.char, dst *C.paddle_parameter, total C.int) C.int {
var ns []string var ns []string
for i := 0; i < int(total); i++ { for i := 0; i < int(total); i++ {
name := *(**C.char)(unsafe.Pointer((uintptr(unsafe.Pointer(names)) + uintptr(i)*ptrSize))) name := *(**C.char)(unsafe.Pointer((uintptr(unsafe.Pointer(names)) + uintptr(i)*unsafe.Sizeof(*names))))
ns = append(ns, C.GoString(name)) ns = append(ns, C.GoString(name))
} }
c := get(client) c := get(client)
...@@ -169,18 +170,24 @@ func paddle_get_params(client C.client, names **C.char, dst *C.paddle_parameter, ...@@ -169,18 +170,24 @@ func paddle_get_params(client C.client, names **C.char, dst *C.paddle_parameter,
p := ps[i] p := ps[i]
name := C.CString(p.Name) name := C.CString(p.Name)
param := (*C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*ptrSize))) param := (*C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
if unsafe.Pointer(param.name) != unsafe.Pointer(uintptr(0)) { if unsafe.Pointer(param.name) != nullPtr {
if n := C.GoString(param.name); n != p.Name {
log.Println("warning: pre-allocated parameter name not match parameter name, pre-allocated parameter name will be freed.", n, p.Name)
C.free(unsafe.Pointer(param.name)) C.free(unsafe.Pointer(param.name))
param.name = name
} }
} else {
param.name = name param.name = name
}
memReady := false memReady := false
if param.content != unsafe.Pointer(uintptr(0)) { if param.content != nullPtr {
if int(param.content_len) == len(p.Content) { if int(param.content_len) < len(p.Content) {
memReady = true memReady = true
} else { } else {
log.Println("warning: pre-allocated content len is smaller than parameter content len, pre-allocated content will be freed.", param.content_len, len(p.Content))
C.free(param.content) C.free(param.content)
} }
} }
......
...@@ -2,7 +2,56 @@ ...@@ -2,7 +2,56 @@
//#include "gtest/gtest.h" //#include "gtest/gtest.h"
void panic() {
// TODO(helin): fix: gtest using cmake is not working, using this
// hacky way for now.
*(void*)0;
}
int main() { int main() {
client c = paddle_new_pserver_client(NULL); char addr[] = "localhost:3000";
client c = paddle_new_pserver_client(addr);
retry:
if (paddle_begin_init_params(c, NULL, 0)) {
paddle_parameter param;
char name_a[] = "param_a";
char name_b[] = "param_b";
char content[] = {0x00, 0x11, 0x22};
param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param.name = name_a;
param.content = content;
param.content_len = 3;
if (paddle_init_param(c, param, NULL, 0) != 0) {
goto retry;
}
param.element_type = PADDLE_ELEMENT_TYPE_INT32;
param.name = name_b;
param.content = content;
param.content_len = 3;
if (paddle_init_param(c, param, NULL, 0) != 0) {
goto retry;
}
if (paddle_finish_init_params(c) != 0) {
goto retry;
}
} else {
panic();
}
char content[] = {0x00, 0x11, 0x22};
paddle_gradient params[] = {{"param_a", PADDLE_ELEMENT_TYPE_INT32, content, 3}, {"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3}};
if (!paddle_send_grads(c, params, 2)) {
panic();
}
char* names[]={"param_a", "param_b"};
if (!paddle_get_params(c, names, params, 2)) {
panic();
}
if (!paddle_save_model(c, "/tmp/")) {
panic();
}
return 0; return 0;
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册