提交 22b5a388 编写于 作者: H Helin Wang

fix according to comments

上级 0e71ab29
......@@ -20,6 +20,8 @@ typedef struct {
} paddle_parameter, paddle_gradient;
typedef int paddle_pserver_client;
#define PSERVER_ERROR -1
#define PSERVER_OK 0
*/
import "C"
......@@ -115,7 +117,7 @@ func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
if selected := c.BeginInitParams(); selected {
return 1
}
return 0
return C.PSERVER_OK
}
//export paddle_init_param
......@@ -133,13 +135,13 @@ func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter,
if err != nil {
if err.Error() == pserver.AlreadyInitialized {
log.Printf("parameter %s already initialized, treat paddle_init_param as sucessful.\n", name)
return 0
return C.PSERVER_OK
}
log.Println(err)
return -1
return C.PSERVER_ERROR
}
return 0
return C.PSERVER_OK
}
//export paddle_finish_init_params
......@@ -149,14 +151,14 @@ func paddle_finish_init_params(client C.paddle_pserver_client) C.int {
if err != nil {
if err.Error() == pserver.AlreadyInitialized {
log.Println("parameters already initialized, treat paddle_finish_init_params as sucessful.")
return 0
return C.PSERVER_OK
}
log.Println(err)
return -1
return C.PSERVER_ERROR
}
return 0
return C.PSERVER_OK
}
//export paddle_send_grads
......@@ -174,10 +176,10 @@ func paddle_send_grads(client C.paddle_pserver_client, grads *C.paddle_gradient,
err := c.SendGrads(gs)
if err != nil {
log.Println(err)
return -1
return C.PSERVER_ERROR
}
return 0
return C.PSERVER_OK
}
//export paddle_get_params
......@@ -191,42 +193,26 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
ps, err := c.GetParams(ns)
if err != nil {
log.Println(err)
return -1
}
names := func() (string, string) {
retNames := ""
for _, p := range ps {
if retNames == "" {
retNames = p.Name
} else {
retNames = ", " + p.Name
}
}
requestedNames := ""
for _, n := range ns {
if requestedNames == "" {
requestedNames = n
} else {
requestedNames = ", " + n
}
}
return requestedNames, retNames
return C.PSERVER_ERROR
}
if len(ps) != len(ns) {
requestedNames, retNames := names()
log.Printf("pserver returned wrong number of parameters. Requested: %s, returned: %s.\n", retNames, requestedNames)
return -1
pn := make([]string, len(ps))
for i, p := range ps {
pn[i] = p.Name
}
log.Printf("pserver returned wrong number of parameters. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", "))
return C.PSERVER_ERROR
}
for i := range ps {
if ns[i] != ps[i].Name {
requestedNames, retNames := names()
log.Printf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.\n", retNames, requestedNames)
return -1
pn := make([]string, len(ps))
for i, p := range ps {
pn[i] = p.Name
}
log.Printf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", "))
return C.PSERVER_ERROR
}
}
......@@ -236,12 +222,12 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
if unsafe.Pointer(param) == nullPtr {
log.Println("must pre-allocate parameter.")
return -1
return C.PSERVER_ERROR
} else {
if unsafe.Pointer(param.content) != nullPtr {
if int(param.content_len) != len(p.Content) {
log.Printf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content))
return -1
return C.PSERVER_ERROR
}
}
}
......@@ -251,7 +237,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
param.element_type = C.paddle_element_type(p.ElementType)
}
return 0
return C.PSERVER_OK
}
//export paddle_save_model
......@@ -261,10 +247,10 @@ func paddle_save_model(client C.paddle_pserver_client, path *C.char) C.int {
err := c.Save(p)
if err != nil {
log.Println(err)
return -1
return C.PSERVER_ERROR
}
return 0
return C.PSERVER_OK
}
func main() {} // Required but ignored
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册