提交 0e2acb8b 编写于 作者: 乔龙飞 提交者: GitHub

Merge pull request #2433 from helinwang/pserver_test

Modify pserver client C API, create better test.
...@@ -74,14 +74,25 @@ typedef enum { ...@@ -74,14 +74,25 @@ typedef enum {
typedef struct { typedef struct {
char* name; char* name;
paddle_element_type element_type; paddle_element_type element_type;
void* content; unsigned char* content;
int content_len; int content_len;
} paddle_parameter, paddle_gradient; } paddle_parameter, paddle_gradient;
typedef struct paddle_pserver_client paddle_pserver_client; typedef int paddle_pserver_client;
paddle_pserver_client* paddle_new_pserver_client(); /**
void paddle_pserver_client_release(paddle_pserver_client* client); * @brief creates a pserver client that talks to etcd for coordination.
*/
paddle_pserver_client paddle_new_etcd_pserver_client(char* etcd_addr);
/**
* @brief creates a pserver client given pserver addresses.
*
* @param pserver_addrs comma-separated pserver addresses.
* @param selected if current pserver client is selected to initialize all parameter servers.
*/
paddle_pserver_client paddle_new_pserver_client(char* pserver_addrs, int selected);
void paddle_pserver_client_release(paddle_pserver_client c);
/** /**
* @brief paddle_begin_init_params begins to initialize parameters on * @brief paddle_begin_init_params begins to initialize parameters on
...@@ -95,7 +106,7 @@ void paddle_pserver_client_release(paddle_pserver_client* client); ...@@ -95,7 +106,7 @@ void paddle_pserver_client_release(paddle_pserver_client* client);
* @return 1 if the trainer is selected to initialize parameter * @return 1 if the trainer is selected to initialize parameter
* servers, otherwise 0. * servers, otherwise 0.
*/ */
int paddle_begin_init_params(paddle_pserver_client* client); int paddle_begin_init_params(paddle_pserver_client client);
/** /**
* @brief paddle_init_param initializes the parameter on parameter * @brief paddle_init_param initializes the parameter on parameter
...@@ -109,7 +120,7 @@ int paddle_begin_init_params(paddle_pserver_client* client); ...@@ -109,7 +120,7 @@ int paddle_begin_init_params(paddle_pserver_client* client);
* @paddle_begin_init_param). Or simply exit the program and wait for * @paddle_begin_init_param). Or simply exit the program and wait for
* the cluster management system to restart the trainer. * the cluster management system to restart the trainer.
*/ */
int paddle_init_param(paddle_pserver_client* client, paddle_parameter param, const unsigned char* param_config_proto, int config_len); int paddle_init_param(paddle_pserver_client client, paddle_parameter param, const unsigned char* param_config_proto, int config_len);
/** /**
* @brief paddle_finish_init_params tells parameter servers client has * @brief paddle_finish_init_params tells parameter servers client has
...@@ -120,7 +131,7 @@ int paddle_init_param(paddle_pserver_client* client, paddle_parameter param, con ...@@ -120,7 +131,7 @@ int paddle_init_param(paddle_pserver_client* client, paddle_parameter param, con
* @paddle_begin_init_param). Or simply exit the program and wait for * @paddle_begin_init_param). Or simply exit the program and wait for
* the cluster management system to restart the trainer. * the cluster management system to restart the trainer.
*/ */
int paddle_finish_init_params(paddle_pserver_client* client); int paddle_finish_init_params(paddle_pserver_client client);
/** /**
* @brief paddle_send_grads sends gradients to parameter servers for * @brief paddle_send_grads sends gradients to parameter servers for
...@@ -131,7 +142,7 @@ int paddle_finish_init_params(paddle_pserver_client* client); ...@@ -131,7 +142,7 @@ int paddle_finish_init_params(paddle_pserver_client* client);
* @param learning_rate the learning rate for the gradients. * @param learning_rate the learning rate for the gradients.
* @return 0 if successful, otherwise -1. * @return 0 if successful, otherwise -1.
*/ */
int paddle_send_grads(paddle_pserver_client* client, const paddle_gradient* grads, int len); int paddle_send_grads(paddle_pserver_client client, const paddle_gradient* grads, int len);
/** /**
* @brief paddle_get_params gets parameters from parameter servers. * @brief paddle_get_params gets parameters from parameter servers.
...@@ -139,13 +150,15 @@ int paddle_send_grads(paddle_pserver_client* client, const paddle_gradient* grad ...@@ -139,13 +150,15 @@ int paddle_send_grads(paddle_pserver_client* client, const paddle_gradient* grad
* paddle_get_params will block until parameters are initialized on * paddle_get_params will block until parameters are initialized on
* the parameter servers. * the parameter servers.
* *
* @param names the array of names of the parameters to get. * @param dst the destination array of parameter pointers to save to.
* @param dst the destination array of parameters to save to. * The parameter pointer must be pre-popullated with required parameter name,
* and the content of parameter must be pre-allocated of the size of required
* parameter on pserver.
* @param len the length of the names array and the paddle_parameter * @param len the length of the names array and the paddle_parameter
* array. * array.
* @return 0 if successful, otherwise -1. * @return 0 if successful, otherwise -1.
*/ */
int paddle_get_params(paddle_pserver_client* client, const char** names, paddle_parameter* dst, int len); int paddle_get_params(paddle_pserver_client client, paddle_parameter** dst, int len);
/** /**
* @brief paddle_save_model indicates parameters to save the parameter * @brief paddle_save_model indicates parameters to save the parameter
...@@ -154,5 +167,5 @@ int paddle_get_params(paddle_pserver_client* client, const char** names, paddle_ ...@@ -154,5 +167,5 @@ int paddle_get_params(paddle_pserver_client* client, const char** names, paddle_
* @param path the path to save parameters. * @param path the path to save parameters.
* @return 0 if successful, otherwise -1. * @return 0 if successful, otherwise -1.
*/ */
int paddle_save_model(paddle_pserver_client* client, const char* path); int paddle_save_model(paddle_pserver_client client, const char* path);
``` ```
...@@ -19,21 +19,9 @@ typedef struct { ...@@ -19,21 +19,9 @@ typedef struct {
int content_len; int content_len;
} paddle_parameter, paddle_gradient; } paddle_parameter, paddle_gradient;
static inline void paddle_release_param(paddle_parameter* param) { typedef int paddle_pserver_client;
if (param != NULL) { #define PSERVER_ERROR -1
if (param->name != NULL) { #define PSERVER_OK 0
free(param->name);
}
if (param->content != NULL) {
free(param->content);
}
free(param);
}
}
typedef int client;
*/ */
import "C" import "C"
...@@ -48,10 +36,10 @@ import ( ...@@ -48,10 +36,10 @@ import (
var nullPtr = unsafe.Pointer(uintptr(0)) var nullPtr = unsafe.Pointer(uintptr(0))
var mu sync.Mutex var mu sync.Mutex
var handleMap = make(map[C.client]*pserver.Client) var handleMap = make(map[C.paddle_pserver_client]*pserver.Client)
var curHandle C.client var curHandle C.paddle_pserver_client
func add(c *pserver.Client) C.client { func add(c *pserver.Client) C.paddle_pserver_client {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
client := curHandle client := curHandle
...@@ -60,13 +48,13 @@ func add(c *pserver.Client) C.client { ...@@ -60,13 +48,13 @@ func add(c *pserver.Client) C.client {
return client return client
} }
func get(client C.client) *pserver.Client { func get(client C.paddle_pserver_client) *pserver.Client {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
return handleMap[client] return handleMap[client]
} }
func remove(client C.client) *pserver.Client { func remove(client C.paddle_pserver_client) *pserver.Client {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
h := handleMap[client] h := handleMap[client]
...@@ -100,7 +88,7 @@ func (l lister) List() []pserver.Server { ...@@ -100,7 +88,7 @@ func (l lister) List() []pserver.Server {
} }
//export paddle_new_pserver_client //export paddle_new_pserver_client
func paddle_new_pserver_client(addrs *C.char, selected int) C.client { func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_client {
a := C.GoString(addrs) a := C.GoString(addrs)
as := strings.Split(a, ",") as := strings.Split(a, ",")
servers := make([]pserver.Server, len(as)) servers := make([]pserver.Server, len(as))
...@@ -113,27 +101,27 @@ func paddle_new_pserver_client(addrs *C.char, selected int) C.client { ...@@ -113,27 +101,27 @@ func paddle_new_pserver_client(addrs *C.char, selected int) C.client {
} }
//export paddle_new_etcd_pserver_client //export paddle_new_etcd_pserver_client
func paddle_new_etcd_pserver_client(etcd_addr *C.char) C.client { func paddle_new_etcd_pserver_client(etcd_addr *C.char) C.paddle_pserver_client {
// TODO(helin): fault tolerant pserver client using etcd. // TODO(helin): fault tolerant pserver client using etcd.
panic("not implemented.") panic("not implemented.")
} }
//export paddle_pserver_client_release //export paddle_pserver_client_release
func paddle_pserver_client_release(client C.client) { func paddle_pserver_client_release(client C.paddle_pserver_client) {
remove(client) remove(client)
} }
//export paddle_begin_init_params //export paddle_begin_init_params
func paddle_begin_init_params(client C.client) C.int { func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
c := get(client) c := get(client)
if selected := c.BeginInitParams(); selected { if selected := c.BeginInitParams(); selected {
return 1 return 1
} }
return 0 return C.PSERVER_OK
} }
//export paddle_init_param //export paddle_init_param
func paddle_init_param(client C.client, param C.paddle_parameter, param_config unsafe.Pointer, config_len C.int) C.int { func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, param_config unsafe.Pointer, config_len C.int) C.int {
et := pserver.ElementType(param.element_type) et := pserver.ElementType(param.element_type)
name := C.GoString(param.name) name := C.GoString(param.name)
content := cArrayToSlice(unsafe.Pointer(param.content), int(param.content_len)) content := cArrayToSlice(unsafe.Pointer(param.content), int(param.content_len))
...@@ -143,28 +131,38 @@ func paddle_init_param(client C.client, param C.paddle_parameter, param_config u ...@@ -143,28 +131,38 @@ func paddle_init_param(client C.client, param C.paddle_parameter, param_config u
} }
c := get(client) c := get(client)
err := c.InitParam(pc) err := c.InitParam(pc)
if err != nil { if err != nil {
if err.Error() == pserver.AlreadyInitialized {
log.Printf("parameter %s already initialized, treat paddle_init_param as sucessful.\n", name)
return C.PSERVER_OK
}
log.Println(err) log.Println(err)
return -1 return C.PSERVER_ERROR
} }
return 0 return C.PSERVER_OK
} }
//export paddle_finish_init_params //export paddle_finish_init_params
func paddle_finish_init_params(client C.client) C.int { func paddle_finish_init_params(client C.paddle_pserver_client) C.int {
c := get(client) c := get(client)
err := c.FinishInitParams() err := c.FinishInitParams()
if err != nil { if err != nil {
if err.Error() == pserver.AlreadyInitialized {
log.Println("parameters already initialized, treat paddle_finish_init_params as sucessful.")
return C.PSERVER_OK
}
log.Println(err) log.Println(err)
return -1 return C.PSERVER_ERROR
} }
return 0 return C.PSERVER_OK
} }
//export paddle_send_grads //export paddle_send_grads
func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C.int { func paddle_send_grads(client C.paddle_pserver_client, grads *C.paddle_gradient, total C.int) C.int {
var gs []pserver.Gradient var gs []pserver.Gradient
for i := 0; i < int(total); i++ { for i := 0; i < int(total); i++ {
grad := (*C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads)))) grad := (*C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads))))
...@@ -178,83 +176,81 @@ func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C ...@@ -178,83 +176,81 @@ func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C
err := c.SendGrads(gs) err := c.SendGrads(gs)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return -1 return C.PSERVER_ERROR
} }
return 0 return C.PSERVER_OK
} }
//export paddle_get_params //export paddle_get_params
func paddle_get_params(client C.client, names **C.char, dst **C.paddle_parameter, total C.int) C.int { func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, total C.int) C.int {
var ns []string var ns []string
for i := 0; i < int(total); i++ { for i := 0; i < int(total); i++ {
name := *(**C.char)(unsafe.Pointer((uintptr(unsafe.Pointer(names)) + uintptr(i)*unsafe.Sizeof(*names)))) param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
ns = append(ns, C.GoString(name)) ns = append(ns, C.GoString(param.name))
} }
c := get(client) c := get(client)
ps, err := c.GetParams(ns) ps, err := c.GetParams(ns)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return -1 return C.PSERVER_ERROR
} }
for i := 0; i < int(total); i++ { if len(ps) != len(ns) {
if i >= len(ps) { pn := make([]string, len(ps))
break for i, p := range ps {
pn[i] = p.Name
} }
log.Printf("pserver returned wrong number of parameters. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", "))
return C.PSERVER_ERROR
}
for i := range ps {
if ns[i] != ps[i].Name {
pn := make([]string, len(ps))
for i, p := range ps {
pn[i] = p.Name
}
log.Printf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", "))
return C.PSERVER_ERROR
}
}
for i := 0; i < int(total); i++ {
p := ps[i] p := ps[i]
param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst)))) param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
nameReady := false
contentAllocated := false
if unsafe.Pointer(param) == nullPtr { if unsafe.Pointer(param) == nullPtr {
param = (*C.paddle_parameter)(C.calloc(1, C.size_t(unsafe.Sizeof(*param)))) log.Println("must pre-allocate parameter.")
return C.PSERVER_ERROR
} else { } else {
if unsafe.Pointer(param.name) != nullPtr {
if n := C.GoString(param.name); n != p.Name {
log.Println("Warning: the pre-allocated parameter name does not match the parameter name, it will be freed.", n, p.Name)
C.free(unsafe.Pointer(param.name))
} else {
nameReady = true
}
}
if unsafe.Pointer(param.content) != nullPtr { if unsafe.Pointer(param.content) != nullPtr {
if int(param.content_len) == len(p.Content) { if int(param.content_len) != len(p.Content) {
contentAllocated = true log.Printf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content))
} else { return C.PSERVER_ERROR
log.Println("Warning: the pre-allocated content len does not match parameter content len, the pre-allocated content will be freed.", param.content_len, len(p.Content))
C.free(unsafe.Pointer(param.content))
} }
} }
} }
if !nameReady {
param.name = C.CString(p.Name)
}
if !contentAllocated {
param.content = (*C.uchar)(C.malloc(C.size_t(len(p.Content))))
}
C.memcpy(unsafe.Pointer(param.content), unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content))) C.memcpy(unsafe.Pointer(param.content), unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
param.content_len = C.int(len(p.Content)) param.content_len = C.int(len(p.Content))
param.element_type = C.paddle_element_type(p.ElementType) param.element_type = C.paddle_element_type(p.ElementType)
} }
return 0 return C.PSERVER_OK
} }
//export paddle_save_model //export paddle_save_model
func paddle_save_model(client C.client, path *C.char) C.int { func paddle_save_model(client C.paddle_pserver_client, path *C.char) C.int {
p := C.GoString(path) p := C.GoString(path)
c := get(client) c := get(client)
err := c.Save(p) err := c.Save(p)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return -1 return C.PSERVER_ERROR
} }
return 0 return C.PSERVER_OK
} }
func main() {} // Required but ignored func main() {} // Required but ignored
...@@ -7,5 +7,7 @@ add_dependencies(main client) ...@@ -7,5 +7,7 @@ add_dependencies(main client)
if(APPLE) if(APPLE)
set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security") set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security")
else()
set(CMAKE_EXE_LINKER_FLAGS "-pthread")
endif() endif()
target_link_libraries(main ${CMAKE_BINARY_DIR}/libclient.a) target_link_libraries(main ${CMAKE_BINARY_DIR}/libclient.a)
...@@ -2,67 +2,87 @@ ...@@ -2,67 +2,87 @@
#include "libclient.h" #include "libclient.h"
void fail() { // TODO(helin): Fix: gtest using cmake is not working, using this
// TODO(helin): fix: gtest using cmake is not working, using this // hacky way for now.
// hacky way for now. #define fail() \
printf("test failed.\n"); fprintf(stderr, "info: %s:%d: ", __FILE__, __LINE__); \
exit(-1); exit(-1);
void sendGrads(paddle_pserver_client c) {
unsigned char grad_a[2000] = {2};
unsigned char grad_b[3000] = {3};
paddle_gradient grads[2] = {
{"param_a", PADDLE_ELEMENT_TYPE_FLOAT32, grad_a, 2000},
{"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, grad_b, 3000}};
if (paddle_send_grads(c, grads, 2)) {
fail();
}
}
void getParams(paddle_pserver_client c) {
paddle_parameter param_a;
paddle_parameter param_b;
char name_a[] = "param_a";
char name_b[] = "param_b";
// Must pre-allocate the prameter content before calling paddle_get_params.
unsigned char content_a[2000] = {};
unsigned char content_b[3000] = {};
param_a.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param_a.name = name_a;
param_a.content = content_a;
param_a.content_len = 2000;
param_b.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param_b.name = name_b;
param_b.content = content_b;
param_b.content_len = 3000;
paddle_parameter* params[2] = {&param_a, &param_b};
if (paddle_get_params(c, params, 2)) {
fail();
}
} }
int main() { int main() {
char addr[] = "localhost:3000"; char addr[] = "localhost:3000";
client c = paddle_new_pserver_client(addr, 1); paddle_pserver_client c = paddle_new_pserver_client(addr, 1);
retry: retry:
if (paddle_begin_init_params(c)) { if (paddle_begin_init_params(c)) {
paddle_parameter param; paddle_parameter param;
char name_a[] = "param_a"; char name_a[] = "param_a";
char name_b[] = "param_b"; char name_b[] = "param_b";
unsigned char content[] = {0x00, 0x11, 0x22}; unsigned char content_a[2000] = {1};
unsigned char content_b[3000] = {0};
param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32; param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param.name = name_a; param.name = name_a;
param.content = content; param.content = content_a;
param.content_len = 3; param.content_len = 2000;
if (paddle_init_param(c, param, NULL, 0) != 0) { int error = paddle_init_param(c, param, NULL, 0);
if (error != 0) {
goto retry; goto retry;
} }
param.element_type = PADDLE_ELEMENT_TYPE_INT32;
param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param.name = name_b; param.name = name_b;
param.content = content; param.content = content_b;
param.content_len = 3; param.content_len = 3000;
if (paddle_init_param(c, param, NULL, 0) != 0) { error = paddle_init_param(c, param, NULL, 0);
if (error != 0) {
goto retry; goto retry;
} }
if (paddle_finish_init_params(c) != 0) {
error = paddle_finish_init_params(c);
if (error != 0) {
goto retry; goto retry;
} }
} else {
fail();
}
unsigned char content[] = {0x00, 0x11, 0x22};
paddle_gradient grads[2] = {
{"param_a", PADDLE_ELEMENT_TYPE_INT32, content, 3},
{"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3}};
if (!paddle_send_grads(c, grads, 2)) {
fail();
}
paddle_parameter* params[2] = {NULL, NULL};
char* names[] = {"param_a", "param_b"};
if (!paddle_get_params(c, names, params, 2)) {
fail();
} }
// get parameters again by reusing the allocated parameter buffers. for (int i = 0; i < 100; i++) {
if (!paddle_get_params(c, names, params, 2)) { sendGrads(c);
fail(); getParams(c);
} }
paddle_release_param(params[0]); if (paddle_save_model(c, "/tmp/")) {
paddle_release_param(params[1]);
if (!paddle_save_model(c, "/tmp/")) {
fail(); fail();
} }
......
...@@ -117,7 +117,7 @@ func TestClientFull(t *testing.T) { ...@@ -117,7 +117,7 @@ func TestClientFull(t *testing.T) {
for i := range params { for i := range params {
if names[i] != params[i].Name { if names[i] != params[i].Name {
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i]) t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i].Name)
} }
} }
} }
...@@ -9,8 +9,10 @@ import ( ...@@ -9,8 +9,10 @@ import (
// ElementType is the type of elements of a Parameter. // ElementType is the type of elements of a Parameter.
type ElementType int type ElementType int
var ErrAlreadyInitialized = errors.New("pserver already initialized") const (
var ErrUninitialized = errors.New("pserver not fully initialized") AlreadyInitialized = "pserver already initialized"
Uninitialized = "pserver not fully initialized"
)
// Supported element types // Supported element types
const ( const (
...@@ -59,7 +61,7 @@ func NewService() *Service { ...@@ -59,7 +61,7 @@ func NewService() *Service {
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error { func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
return ErrAlreadyInitialized return errors.New(AlreadyInitialized)
default: default:
} }
...@@ -80,7 +82,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er ...@@ -80,7 +82,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error { func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
return ErrAlreadyInitialized return errors.New(AlreadyInitialized)
default: default:
} }
...@@ -94,7 +96,7 @@ func (s *Service) SendGrad(g Gradient, dummy *int) error { ...@@ -94,7 +96,7 @@ func (s *Service) SendGrad(g Gradient, dummy *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
default: default:
return ErrUninitialized return errors.New(Uninitialized)
} }
s.mu.Lock() s.mu.Lock()
......
...@@ -16,7 +16,7 @@ func TestFull(t *testing.T) { ...@@ -16,7 +16,7 @@ func TestFull(t *testing.T) {
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32 p.ElementType = pserver.Int32
var dummy int var dummy int
err := s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy) err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, &dummy)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
...@@ -25,7 +25,7 @@ func TestFull(t *testing.T) { ...@@ -25,7 +25,7 @@ func TestFull(t *testing.T) {
p1.Name = "param_b" p1.Name = "param_b"
p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
p1.ElementType = pserver.Float32 p1.ElementType = pserver.Float32
err = s.InitParam(pserver.ParameterWithConfig{p1, nil}, &dummy) err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: nil}, &dummy)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
...@@ -81,7 +81,7 @@ func TestMultipleInit(t *testing.T) { ...@@ -81,7 +81,7 @@ func TestMultipleInit(t *testing.T) {
} }
err = s.FinishInitParams(0, &dummy) err = s.FinishInitParams(0, &dummy)
if err != pserver.ErrAlreadyInitialized { if err.Error() != pserver.AlreadyInitialized {
t.FailNow() t.FailNow()
} }
} }
...@@ -90,7 +90,7 @@ func TestUninitialized(t *testing.T) { ...@@ -90,7 +90,7 @@ func TestUninitialized(t *testing.T) {
s := pserver.NewService() s := pserver.NewService()
var dummy int var dummy int
err := s.SendGrad(pserver.Gradient{}, &dummy) err := s.SendGrad(pserver.Gradient{}, &dummy)
if err != pserver.ErrUninitialized { if err.Error() != pserver.Uninitialized {
t.FailNow() t.FailNow()
} }
} }
...@@ -135,7 +135,7 @@ func TestBlockUntilInitialized(t *testing.T) { ...@@ -135,7 +135,7 @@ func TestBlockUntilInitialized(t *testing.T) {
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32 p.ElementType = pserver.Int32
var dummy int var dummy int
err := s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy) err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, &dummy)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册