提交 d08c8ea7 编写于 作者: H Helin Wang

change interface for clearity

上级 e21d56ee
...@@ -150,7 +150,7 @@ func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C ...@@ -150,7 +150,7 @@ func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C
} }
//export paddle_get_params //export paddle_get_params
func paddle_get_params(client C.client, names **C.char, dst *C.paddle_parameter, reuse C.int, total C.int) C.int { func paddle_get_params(client C.client, names **C.char, dst **C.paddle_parameter, total C.int) C.int {
var ns []string var ns []string
for i := 0; i < int(total); i++ { for i := 0; i < int(total); i++ {
name := *(**C.char)(unsafe.Pointer((uintptr(unsafe.Pointer(names)) + uintptr(i)*unsafe.Sizeof(*names)))) name := *(**C.char)(unsafe.Pointer((uintptr(unsafe.Pointer(names)) + uintptr(i)*unsafe.Sizeof(*names))))
...@@ -169,30 +169,36 @@ func paddle_get_params(client C.client, names **C.char, dst *C.paddle_parameter, ...@@ -169,30 +169,36 @@ func paddle_get_params(client C.client, names **C.char, dst *C.paddle_parameter,
} }
p := ps[i] p := ps[i]
name := C.CString(p.Name) param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
param := (*C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst)))) nameReady := false
contentAllocated := false
if reuse != 0 && unsafe.Pointer(param.name) != nullPtr {
if n := C.GoString(param.name); n != p.Name { if unsafe.Pointer(param) == nullPtr {
log.Println("warning: pre-allocated parameter name not match parameter name, pre-allocated parameter name will be freed.", n, p.Name) param = (*C.paddle_parameter)(C.calloc(1, C.size_t(unsafe.Sizeof(*param))))
C.free(unsafe.Pointer(param.name))
param.name = name
}
} else { } else {
param.name = name if unsafe.Pointer(param.name) != nullPtr {
} if n := C.GoString(param.name); n != p.Name {
log.Println("warning: pre-allocated parameter name not match parameter name, pre-allocated parameter name will be freed.", n, p.Name)
C.free(unsafe.Pointer(param.name))
} else {
nameReady = true
}
}
memReady := false if param.content != nullPtr {
if reuse != 0 && param.content != nullPtr { if int(param.content_len) == len(p.Content) {
if int(param.content_len) < len(p.Content) { contentAllocated = true
memReady = true } else {
} else { log.Println("warning: pre-allocated content len does not match parameter content len, pre-allocated content will be freed.", param.content_len, len(p.Content))
log.Println("warning: pre-allocated content len is smaller than parameter content len, pre-allocated content will be freed.", param.content_len, len(p.Content)) C.free(param.content)
C.free(param.content) }
} }
} }
if !memReady { if !nameReady {
param.name = C.CString(p.Name)
}
if !contentAllocated {
param.content = C.malloc(C.size_t(len(p.Content))) param.content = C.malloc(C.size_t(len(p.Content)))
} }
C.memcpy(param.content, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content))) C.memcpy(param.content, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
......
...@@ -45,9 +45,9 @@ int main() { ...@@ -45,9 +45,9 @@ int main() {
panic(); panic();
} }
paddle_parameter params[2]; paddle_parameter* params[2] = {NULL, NULL};
char* names[]={"param_a", "param_b"}; char* names[]={"param_a", "param_b"};
if (!paddle_get_params(c, names, params, 0, 2)) { if (!paddle_get_params(c, names, params, 2)) {
panic(); panic();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册