提交 22b5a388 编写于 作者: H Helin Wang

fix according to comments

上级 0e71ab29
...@@ -20,6 +20,8 @@ typedef struct { ...@@ -20,6 +20,8 @@ typedef struct {
} paddle_parameter, paddle_gradient; } paddle_parameter, paddle_gradient;
typedef int paddle_pserver_client; typedef int paddle_pserver_client;
#define PSERVER_ERROR -1
#define PSERVER_OK 0
*/ */
import "C" import "C"
...@@ -115,7 +117,7 @@ func paddle_begin_init_params(client C.paddle_pserver_client) C.int { ...@@ -115,7 +117,7 @@ func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
if selected := c.BeginInitParams(); selected { if selected := c.BeginInitParams(); selected {
return 1 return 1
} }
return 0 return C.PSERVER_OK
} }
//export paddle_init_param //export paddle_init_param
...@@ -133,13 +135,13 @@ func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, ...@@ -133,13 +135,13 @@ func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter,
if err != nil { if err != nil {
if err.Error() == pserver.AlreadyInitialized { if err.Error() == pserver.AlreadyInitialized {
log.Printf("parameter %s already initialized, treat paddle_init_param as sucessful.\n", name) log.Printf("parameter %s already initialized, treat paddle_init_param as sucessful.\n", name)
return 0 return C.PSERVER_OK
} }
log.Println(err) log.Println(err)
return -1 return C.PSERVER_ERROR
} }
return 0 return C.PSERVER_OK
} }
//export paddle_finish_init_params //export paddle_finish_init_params
...@@ -149,14 +151,14 @@ func paddle_finish_init_params(client C.paddle_pserver_client) C.int { ...@@ -149,14 +151,14 @@ func paddle_finish_init_params(client C.paddle_pserver_client) C.int {
if err != nil { if err != nil {
if err.Error() == pserver.AlreadyInitialized { if err.Error() == pserver.AlreadyInitialized {
log.Println("parameters already initialized, treat paddle_finish_init_params as sucessful.") log.Println("parameters already initialized, treat paddle_finish_init_params as sucessful.")
return 0 return C.PSERVER_OK
} }
log.Println(err) log.Println(err)
return -1 return C.PSERVER_ERROR
} }
return 0 return C.PSERVER_OK
} }
//export paddle_send_grads //export paddle_send_grads
...@@ -174,10 +176,10 @@ func paddle_send_grads(client C.paddle_pserver_client, grads *C.paddle_gradient, ...@@ -174,10 +176,10 @@ func paddle_send_grads(client C.paddle_pserver_client, grads *C.paddle_gradient,
err := c.SendGrads(gs) err := c.SendGrads(gs)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return -1 return C.PSERVER_ERROR
} }
return 0 return C.PSERVER_OK
} }
//export paddle_get_params //export paddle_get_params
...@@ -191,42 +193,26 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -191,42 +193,26 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
ps, err := c.GetParams(ns) ps, err := c.GetParams(ns)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return -1 return C.PSERVER_ERROR
}
names := func() (string, string) {
retNames := ""
for _, p := range ps {
if retNames == "" {
retNames = p.Name
} else {
retNames = ", " + p.Name
}
}
requestedNames := ""
for _, n := range ns {
if requestedNames == "" {
requestedNames = n
} else {
requestedNames = ", " + n
}
}
return requestedNames, retNames
} }
if len(ps) != len(ns) { if len(ps) != len(ns) {
requestedNames, retNames := names() pn := make([]string, len(ps))
log.Printf("pserver returned wrong number of parameters. Requested: %s, returned: %s.\n", retNames, requestedNames) for i, p := range ps {
return -1 pn[i] = p.Name
}
log.Printf("pserver returned wrong number of parameters. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", "))
return C.PSERVER_ERROR
} }
for i := range ps { for i := range ps {
if ns[i] != ps[i].Name { if ns[i] != ps[i].Name {
requestedNames, retNames := names() pn := make([]string, len(ps))
log.Printf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.\n", retNames, requestedNames) for i, p := range ps {
return -1 pn[i] = p.Name
}
log.Printf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", "))
return C.PSERVER_ERROR
} }
} }
...@@ -236,12 +222,12 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -236,12 +222,12 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
if unsafe.Pointer(param) == nullPtr { if unsafe.Pointer(param) == nullPtr {
log.Println("must pre-allocate parameter.") log.Println("must pre-allocate parameter.")
return -1 return C.PSERVER_ERROR
} else { } else {
if unsafe.Pointer(param.content) != nullPtr { if unsafe.Pointer(param.content) != nullPtr {
if int(param.content_len) != len(p.Content) { if int(param.content_len) != len(p.Content) {
log.Printf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content)) log.Printf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content))
return -1 return C.PSERVER_ERROR
} }
} }
} }
...@@ -251,7 +237,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ...@@ -251,7 +237,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
param.element_type = C.paddle_element_type(p.ElementType) param.element_type = C.paddle_element_type(p.ElementType)
} }
return 0 return C.PSERVER_OK
} }
//export paddle_save_model //export paddle_save_model
...@@ -261,10 +247,10 @@ func paddle_save_model(client C.paddle_pserver_client, path *C.char) C.int { ...@@ -261,10 +247,10 @@ func paddle_save_model(client C.paddle_pserver_client, path *C.char) C.int {
err := c.Save(p) err := c.Save(p)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return -1 return C.PSERVER_ERROR
} }
return 0 return C.PSERVER_OK
} }
func main() {} // Required but ignored func main() {} // Required but ignored
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册