提交 a5091ef2 编写于 作者: H Helin Wang

fix bug, add functional test

上级 6011b2e7
......@@ -31,10 +31,7 @@ import (
"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
)
const (
ptrSize = unsafe.Sizeof(uintptr(0))
)
var nullPtr = unsafe.Pointer(uintptr(0))
var mu sync.Mutex
var handleMap = make(map[C.client]*pserver.Client)
var curHandle C.client
......@@ -63,6 +60,10 @@ func remove(client C.client) *pserver.Client {
}
func cArrayToSlice(p unsafe.Pointer, len int) []byte {
if p == nullPtr {
return nil
}
// create a Go clice backed by a C array,
// reference: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
return (*[1 << 30]byte)(p)[:len:len]
......@@ -131,7 +132,7 @@ func paddle_finish_init_params(client C.client) C.int {
func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C.int {
var gs []pserver.Gradient
for i := 0; i < int(total); i++ {
grad := (*C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*ptrSize)))
grad := (*C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads))))
et := pserver.ElementType(grad.element_type)
name := C.GoString(grad.name)
content := cArrayToSlice(unsafe.Pointer(grad.content), int(grad.content_len))
......@@ -152,7 +153,7 @@ func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C
func paddle_get_params(client C.client, names **C.char, dst *C.paddle_parameter, total C.int) C.int {
var ns []string
for i := 0; i < int(total); i++ {
name := *(**C.char)(unsafe.Pointer((uintptr(unsafe.Pointer(names)) + uintptr(i)*ptrSize)))
name := *(**C.char)(unsafe.Pointer((uintptr(unsafe.Pointer(names)) + uintptr(i)*unsafe.Sizeof(*names))))
ns = append(ns, C.GoString(name))
}
c := get(client)
......@@ -169,18 +170,24 @@ func paddle_get_params(client C.client, names **C.char, dst *C.paddle_parameter,
p := ps[i]
name := C.CString(p.Name)
param := (*C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*ptrSize)))
param := (*C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
if unsafe.Pointer(param.name) != unsafe.Pointer(uintptr(0)) {
if unsafe.Pointer(param.name) != nullPtr {
if n := C.GoString(param.name); n != p.Name {
log.Println("warning: pre-allocated parameter name not match parameter name, pre-allocated parameter name will be freed.", n, p.Name)
C.free(unsafe.Pointer(param.name))
param.name = name
}
} else {
param.name = name
}
memReady := false
if param.content != unsafe.Pointer(uintptr(0)) {
if int(param.content_len) == len(p.Content) {
if param.content != nullPtr {
if int(param.content_len) < len(p.Content) {
memReady = true
} else {
log.Println("warning: pre-allocated content len is smaller than parameter content len, pre-allocated content will be freed.", param.content_len, len(p.Content))
C.free(param.content)
}
}
......
......@@ -2,7 +2,56 @@
//#include "gtest/gtest.h"
void panic() {
// TODO(helin): fix: gtest using cmake is not working, using this
// hacky way for now.
*(void*)0;
}
int main() {
client c = paddle_new_pserver_client(NULL);
char addr[] = "localhost:3000";
client c = paddle_new_pserver_client(addr);
retry:
if (paddle_begin_init_params(c, NULL, 0)) {
paddle_parameter param;
char name_a[] = "param_a";
char name_b[] = "param_b";
char content[] = {0x00, 0x11, 0x22};
param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param.name = name_a;
param.content = content;
param.content_len = 3;
if (paddle_init_param(c, param, NULL, 0) != 0) {
goto retry;
}
param.element_type = PADDLE_ELEMENT_TYPE_INT32;
param.name = name_b;
param.content = content;
param.content_len = 3;
if (paddle_init_param(c, param, NULL, 0) != 0) {
goto retry;
}
if (paddle_finish_init_params(c) != 0) {
goto retry;
}
} else {
panic();
}
char content[] = {0x00, 0x11, 0x22};
paddle_gradient params[] = {{"param_a", PADDLE_ELEMENT_TYPE_INT32, content, 3}, {"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3}};
if (!paddle_send_grads(c, params, 2)) {
panic();
}
char* names[]={"param_a", "param_b"};
if (!paddle_get_params(c, names, params, 2)) {
panic();
}
if (!paddle_save_model(c, "/tmp/")) {
panic();
}
return 0;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册