diff --git a/go/pserver/cclient/cclient.go b/go/pserver/cclient/cclient.go index 6a76ec920e54cb4b7502ea1e7145eec87f547a5e..e753b461bc724109b6d285f8a8aa90f78a7cfd78 100644 --- a/go/pserver/cclient/cclient.go +++ b/go/pserver/cclient/cclient.go @@ -20,6 +20,8 @@ typedef struct { } paddle_parameter, paddle_gradient; typedef int paddle_pserver_client; +#define PSERVER_ERROR -1 +#define PSERVER_OK 0 */ import "C" @@ -115,7 +117,7 @@ func paddle_begin_init_params(client C.paddle_pserver_client) C.int { if selected := c.BeginInitParams(); selected { return 1 } - return 0 + return C.PSERVER_OK } //export paddle_init_param @@ -133,13 +135,13 @@ func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, if err != nil { if err.Error() == pserver.AlreadyInitialized { log.Printf("parameter %s already initialized, treat paddle_init_param as sucessful.\n", name) - return 0 + return C.PSERVER_OK } log.Println(err) - return -1 + return C.PSERVER_ERROR } - return 0 + return C.PSERVER_OK } //export paddle_finish_init_params @@ -149,14 +151,14 @@ func paddle_finish_init_params(client C.paddle_pserver_client) C.int { if err != nil { if err.Error() == pserver.AlreadyInitialized { log.Println("parameters already initialized, treat paddle_finish_init_params as sucessful.") - return 0 + return C.PSERVER_OK } log.Println(err) - return -1 + return C.PSERVER_ERROR } - return 0 + return C.PSERVER_OK } //export paddle_send_grads @@ -174,10 +176,10 @@ func paddle_send_grads(client C.paddle_pserver_client, grads *C.paddle_gradient, err := c.SendGrads(gs) if err != nil { log.Println(err) - return -1 + return C.PSERVER_ERROR } - return 0 + return C.PSERVER_OK } //export paddle_get_params @@ -191,42 +193,26 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, ps, err := c.GetParams(ns) if err != nil { log.Println(err) - return -1 - } - - names := func() (string, string) { - retNames := "" - for _, p := range ps { - if retNames == "" { - retNames = p.Name - } else { - retNames = ", " + p.Name - } - } - - requestedNames := "" - for _, n := range ns { - if requestedNames == "" { - requestedNames = n - } else { - requestedNames = ", " + n - } - } - - return requestedNames, retNames + return C.PSERVER_ERROR } if len(ps) != len(ns) { - requestedNames, retNames := names() - log.Printf("pserver returned wrong number of parameters. Requested: %s, returned: %s.\n", retNames, requestedNames) - return -1 + pn := make([]string, len(ps)) + for i, p := range ps { + pn[i] = p.Name + } + log.Printf("pserver returned wrong number of parameters. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", ")) + return C.PSERVER_ERROR } for i := range ps { if ns[i] != ps[i].Name { - requestedNames, retNames := names() - log.Printf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.\n", retNames, requestedNames) - return -1 + pn := make([]string, len(ps)) + for i, p := range ps { + pn[i] = p.Name + } + log.Printf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", ")) + return C.PSERVER_ERROR } } @@ -236,12 +222,12 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, if unsafe.Pointer(param) == nullPtr { log.Println("must pre-allocate parameter.") - return -1 + return C.PSERVER_ERROR } else { if unsafe.Pointer(param.content) != nullPtr { if int(param.content_len) != len(p.Content) { log.Printf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content)) - return -1 + return C.PSERVER_ERROR } } } @@ -251,7 +237,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, param.element_type = C.paddle_element_type(p.ElementType) } - return 0 + return C.PSERVER_OK } //export paddle_save_model @@ -261,10 +247,10 @@ func paddle_save_model(client C.paddle_pserver_client, path *C.char) C.int { err := c.Save(p) if err != nil { log.Println(err) - return -1 + return C.PSERVER_ERROR } - return 0 + return C.PSERVER_OK } func main() {} // Required but ignored