提交 ce3408a7 编写于 作者: H Helin Wang

modify pserver client C API, create better test

Please refer to the change design doc for what in API have changed.
上级 d3389765
......@@ -74,14 +74,25 @@ typedef enum {
typedef struct {
char* name;
paddle_element_type element_type;
void* content;
unsigned char* content;
int content_len;
} paddle_parameter, paddle_gradient;
typedef struct paddle_pserver_client paddle_pserver_client;
typedef int paddle_pserver_client;
paddle_pserver_client* paddle_new_pserver_client();
void paddle_pserver_client_release(paddle_pserver_client* client);
/**
* @brief creates a pserver client that talks to etcd for coordination.
*/
paddle_pserver_client paddle_new_etcd_pserver_client(char* etcd_addr);
/**
* @brief creates a pserver client given pserver addresses.
*
* @param pserver_addrs comma-separated pserver addresses.
* @param selected if current pserver client is selected to initialize all parameter servers.
*/
paddle_pserver_client paddle_new_pserver_client(char* pserver_addrs, int selected);
void paddle_pserver_client_release(paddle_pserver_client c);
/**
* @brief paddle_begin_init_params begins to initialize parameters on
......@@ -95,7 +106,7 @@ void paddle_pserver_client_release(paddle_pserver_client* client);
* @return 1 if the trainer is selected to initialize parameter
* servers, otherwise 0.
*/
int paddle_begin_init_params(paddle_pserver_client* client);
int paddle_begin_init_params(paddle_pserver_client client);
/**
* @brief paddle_init_param initializes the parameter on parameter
......@@ -109,7 +120,7 @@ int paddle_begin_init_params(paddle_pserver_client* client);
* @paddle_begin_init_param). Or simply exit the program and wait for
* the cluster management system to restart the trainer.
*/
int paddle_init_param(paddle_pserver_client* client, paddle_parameter param, const unsigned char* param_config_proto, int config_len);
int paddle_init_param(paddle_pserver_client client, paddle_parameter param, const unsigned char* param_config_proto, int config_len);
/**
* @brief paddle_finish_init_params tells parameter servers client has
......@@ -120,7 +131,7 @@ int paddle_init_param(paddle_pserver_client* client, paddle_parameter param, con
* @paddle_begin_init_param). Or simply exit the program and wait for
* the cluster management system to restart the trainer.
*/
int paddle_finish_init_params(paddle_pserver_client* client);
int paddle_finish_init_params(paddle_pserver_client client);
/**
* @brief paddle_send_grads sends gradients to parameter servers for
......@@ -131,7 +142,7 @@ int paddle_finish_init_params(paddle_pserver_client* client);
* @param learning_rate the learning rate for the gradients.
* @return 0 if successful, otherwise -1.
*/
int paddle_send_grads(paddle_pserver_client* client, const paddle_gradient* grads, int len);
int paddle_send_grads(paddle_pserver_client client, const paddle_gradient* grads, int len);
/**
* @brief paddle_get_params gets parameters from parameter servers.
......@@ -139,13 +150,15 @@ int paddle_send_grads(paddle_pserver_client* client, const paddle_gradient* grad
* paddle_get_params will block until parameters are initialized on
* the parameter servers.
*
* @param names the array of names of the parameters to get.
* @param dst the destination array of parameters to save to.
* @param dst the destination array of parameter pointers to save to.
* The parameter pointer must be pre-popullated with required parameter name,
* and the content of parameter must be pre-allocated of the size of required
* parameter on pserver.
* @param len the length of the names array and the paddle_parameter
* array.
* @return 0 if successful, otherwise -1.
*/
int paddle_get_params(paddle_pserver_client* client, const char** names, paddle_parameter* dst, int len);
int paddle_get_params(paddle_pserver_client client, paddle_parameter** dst, int len);
/**
* @brief paddle_save_model indicates parameters to save the parameter
......@@ -154,5 +167,5 @@ int paddle_get_params(paddle_pserver_client* client, const char** names, paddle_
* @param path the path to save parameters.
* @return 0 if successful, otherwise -1.
*/
int paddle_save_model(paddle_pserver_client* client, const char* path);
int paddle_save_model(paddle_pserver_client client, const char* path);
```
......@@ -19,21 +19,7 @@ typedef struct {
int content_len;
} paddle_parameter, paddle_gradient;
static inline void paddle_release_param(paddle_parameter* param) {
if (param != NULL) {
if (param->name != NULL) {
free(param->name);
}
if (param->content != NULL) {
free(param->content);
}
free(param);
}
}
typedef int client;
typedef int paddle_pserver_client;
*/
import "C"
......@@ -48,10 +34,10 @@ import (
var nullPtr = unsafe.Pointer(uintptr(0))
var mu sync.Mutex
var handleMap = make(map[C.client]*pserver.Client)
var curHandle C.client
var handleMap = make(map[C.paddle_pserver_client]*pserver.Client)
var curHandle C.paddle_pserver_client
func add(c *pserver.Client) C.client {
func add(c *pserver.Client) C.paddle_pserver_client {
mu.Lock()
defer mu.Unlock()
client := curHandle
......@@ -60,13 +46,13 @@ func add(c *pserver.Client) C.client {
return client
}
func get(client C.client) *pserver.Client {
func get(client C.paddle_pserver_client) *pserver.Client {
mu.Lock()
defer mu.Unlock()
return handleMap[client]
}
func remove(client C.client) *pserver.Client {
func remove(client C.paddle_pserver_client) *pserver.Client {
mu.Lock()
defer mu.Unlock()
h := handleMap[client]
......@@ -100,7 +86,7 @@ func (l lister) List() []pserver.Server {
}
//export paddle_new_pserver_client
func paddle_new_pserver_client(addrs *C.char, selected int) C.client {
func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_client {
a := C.GoString(addrs)
as := strings.Split(a, ",")
servers := make([]pserver.Server, len(as))
......@@ -113,18 +99,18 @@ func paddle_new_pserver_client(addrs *C.char, selected int) C.client {
}
//export paddle_new_etcd_pserver_client
func paddle_new_etcd_pserver_client(etcd_addr *C.char) C.client {
func paddle_new_etcd_pserver_client(etcd_addr *C.char) C.paddle_pserver_client {
// TODO(helin): fault tolerant pserver client using etcd.
panic("not implemented.")
}
//export paddle_pserver_client_release
func paddle_pserver_client_release(client C.client) {
func paddle_pserver_client_release(client C.paddle_pserver_client) {
remove(client)
}
//export paddle_begin_init_params
func paddle_begin_init_params(client C.client) C.int {
func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
c := get(client)
if selected := c.BeginInitParams(); selected {
return 1
......@@ -133,7 +119,7 @@ func paddle_begin_init_params(client C.client) C.int {
}
//export paddle_init_param
func paddle_init_param(client C.client, param C.paddle_parameter, param_config unsafe.Pointer, config_len C.int) C.int {
func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, param_config unsafe.Pointer, config_len C.int) C.int {
et := pserver.ElementType(param.element_type)
name := C.GoString(param.name)
content := cArrayToSlice(unsafe.Pointer(param.content), int(param.content_len))
......@@ -143,7 +129,12 @@ func paddle_init_param(client C.client, param C.paddle_parameter, param_config u
}
c := get(client)
err := c.InitParam(pc)
if err != nil {
if err.Error() == pserver.AlreadyInitialized {
log.Println("parameter", name, "already initialized, treat paddle_init_param as sucessful.")
return 0
}
log.Println(err)
return -1
}
......@@ -152,10 +143,15 @@ func paddle_init_param(client C.client, param C.paddle_parameter, param_config u
}
//export paddle_finish_init_params
func paddle_finish_init_params(client C.client) C.int {
func paddle_finish_init_params(client C.paddle_pserver_client) C.int {
c := get(client)
err := c.FinishInitParams()
if err != nil {
if err.Error() == pserver.AlreadyInitialized {
log.Println("parameters already initialized, treat paddle_finish_init_params as sucessful.")
return 0
}
log.Println(err)
return -1
}
......@@ -164,7 +160,7 @@ func paddle_finish_init_params(client C.client) C.int {
}
//export paddle_send_grads
func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C.int {
func paddle_send_grads(client C.paddle_pserver_client, grads *C.paddle_gradient, total C.int) C.int {
var gs []pserver.Gradient
for i := 0; i < int(total); i++ {
grad := (*C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads))))
......@@ -185,11 +181,11 @@ func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C
}
//export paddle_get_params
func paddle_get_params(client C.client, names **C.char, dst **C.paddle_parameter, total C.int) C.int {
func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, total C.int) C.int {
var ns []string
for i := 0; i < int(total); i++ {
name := *(**C.char)(unsafe.Pointer((uintptr(unsafe.Pointer(names)) + uintptr(i)*unsafe.Sizeof(*names))))
ns = append(ns, C.GoString(name))
param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
ns = append(ns, C.GoString(param.name))
}
c := get(client)
ps, err := c.GetParams(ns)
......@@ -198,44 +194,32 @@ func paddle_get_params(client C.client, names **C.char, dst **C.paddle_parameter
return -1
}
for i := 0; i < int(total); i++ {
if i >= len(ps) {
break
if len(ps) != len(ns) {
return -1
}
for i := range ps {
if ns[i] != ps[i].Name {
return -1
}
}
for i := 0; i < int(total); i++ {
p := ps[i]
param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
nameReady := false
contentAllocated := false
if unsafe.Pointer(param) == nullPtr {
param = (*C.paddle_parameter)(C.calloc(1, C.size_t(unsafe.Sizeof(*param))))
log.Println("Error: must pre-allocate parameter.")
return -1
} else {
if unsafe.Pointer(param.name) != nullPtr {
if n := C.GoString(param.name); n != p.Name {
log.Println("Warning: the pre-allocated parameter name does not match the parameter name, it will be freed.", n, p.Name)
C.free(unsafe.Pointer(param.name))
} else {
nameReady = true
}
}
if unsafe.Pointer(param.content) != nullPtr {
if int(param.content_len) == len(p.Content) {
contentAllocated = true
} else {
log.Println("Warning: the pre-allocated content len does not match parameter content len, the pre-allocated content will be freed.", param.content_len, len(p.Content))
C.free(unsafe.Pointer(param.content))
if int(param.content_len) != len(p.Content) {
log.Println("Error: the pre-allocated content len does not match parameter content len.", param.content_len, len(p.Content))
return -1
}
}
}
if !nameReady {
param.name = C.CString(p.Name)
}
if !contentAllocated {
param.content = (*C.uchar)(C.malloc(C.size_t(len(p.Content))))
}
C.memcpy(unsafe.Pointer(param.content), unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
param.content_len = C.int(len(p.Content))
param.element_type = C.paddle_element_type(p.ElementType)
......@@ -245,7 +229,7 @@ func paddle_get_params(client C.client, names **C.char, dst **C.paddle_parameter
}
//export paddle_save_model
func paddle_save_model(client C.client, path *C.char) C.int {
func paddle_save_model(client C.paddle_pserver_client, path *C.char) C.int {
p := C.GoString(path)
c := get(client)
err := c.Save(p)
......
......@@ -7,5 +7,7 @@ add_dependencies(main client)
if(APPLE)
set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security")
else()
set(CMAKE_EXE_LINKER_FLAGS "-pthread")
endif()
target_link_libraries(main ${CMAKE_BINARY_DIR}/libclient.a)
......@@ -2,67 +2,87 @@
#include "libclient.h"
void fail() {
// TODO(helin): fix: gtest using cmake is not working, using this
// hacky way for now.
printf("test failed.\n");
// TODO(helin): fix: gtest using cmake is not working, using this
// hacky way for now.
#define fail() \
fprintf(stderr, "info: %s:%d: ", __FILE__, __LINE__); \
exit(-1);
void sendGrads(paddle_pserver_client c) {
unsigned char grad_a[2000] = {2};
unsigned char grad_b[3000] = {3};
paddle_gradient grads[2] = {
{"param_a", PADDLE_ELEMENT_TYPE_FLOAT32, grad_a, 2000},
{"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, grad_b, 3000}};
if (paddle_send_grads(c, grads, 2)) {
fail();
}
}
void getParams(paddle_pserver_client c) {
paddle_parameter param_a;
paddle_parameter param_b;
char name_a[] = "param_a";
char name_b[] = "param_b";
// must pre-allocate the content for parameter to receive.
unsigned char content_a[2000] = {};
unsigned char content_b[3000] = {};
param_a.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param_a.name = name_a;
param_a.content = content_a;
param_a.content_len = 2000;
param_b.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param_b.name = name_b;
param_b.content = content_b;
param_b.content_len = 3000;
paddle_parameter* params[2] = {&param_a, &param_b};
if (paddle_get_params(c, params, 2)) {
fail();
}
}
int main() {
char addr[] = "localhost:3000";
client c = paddle_new_pserver_client(addr, 1);
paddle_pserver_client c = paddle_new_pserver_client(addr, 1);
retry:
if (paddle_begin_init_params(c)) {
paddle_parameter param;
char name_a[] = "param_a";
char name_b[] = "param_b";
unsigned char content[] = {0x00, 0x11, 0x22};
unsigned char content_a[2000] = {1};
unsigned char content_b[3000] = {0};
param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param.name = name_a;
param.content = content;
param.content_len = 3;
if (paddle_init_param(c, param, NULL, 0) != 0) {
param.content = content_a;
param.content_len = 2000;
int error = paddle_init_param(c, param, NULL, 0);
if (error != 0) {
goto retry;
}
param.element_type = PADDLE_ELEMENT_TYPE_INT32;
param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param.name = name_b;
param.content = content;
param.content_len = 3;
if (paddle_init_param(c, param, NULL, 0) != 0) {
param.content = content_b;
param.content_len = 3000;
error = paddle_init_param(c, param, NULL, 0);
if (error != 0) {
goto retry;
}
if (paddle_finish_init_params(c) != 0) {
error = paddle_finish_init_params(c);
if (error != 0) {
goto retry;
}
} else {
fail();
}
unsigned char content[] = {0x00, 0x11, 0x22};
paddle_gradient grads[2] = {
{"param_a", PADDLE_ELEMENT_TYPE_INT32, content, 3},
{"param_b", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3}};
if (!paddle_send_grads(c, grads, 2)) {
fail();
}
paddle_parameter* params[2] = {NULL, NULL};
char* names[] = {"param_a", "param_b"};
if (!paddle_get_params(c, names, params, 2)) {
fail();
}
// get parameters again by reusing the allocated parameter buffers.
if (!paddle_get_params(c, names, params, 2)) {
fail();
for (int i = 0; i < 100; i++) {
sendGrads(c);
getParams(c);
}
paddle_release_param(params[0]);
paddle_release_param(params[1]);
if (!paddle_save_model(c, "/tmp/")) {
if (paddle_save_model(c, "/tmp/")) {
fail();
}
......
......@@ -117,7 +117,7 @@ func TestClientFull(t *testing.T) {
for i := range params {
if names[i] != params[i].Name {
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i])
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i].Name)
}
}
}
......@@ -9,8 +9,10 @@ import (
// ElementType is the type of elements of a Parameter.
type ElementType int
var ErrAlreadyInitialized = errors.New("pserver already initialized")
var ErrUninitialized = errors.New("pserver not fully initialized")
const (
AlreadyInitialized = "pserver already initialized"
Uninitialized = "pserver not fully initialized"
)
// Supported element types
const (
......@@ -59,7 +61,7 @@ func NewService() *Service {
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
select {
case <-s.initialized:
return ErrAlreadyInitialized
return errors.New(AlreadyInitialized)
default:
}
......@@ -80,7 +82,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
select {
case <-s.initialized:
return ErrAlreadyInitialized
return errors.New(AlreadyInitialized)
default:
}
......@@ -94,7 +96,7 @@ func (s *Service) SendGrad(g Gradient, dummy *int) error {
select {
case <-s.initialized:
default:
return ErrUninitialized
return errors.New(Uninitialized)
}
s.mu.Lock()
......
......@@ -16,7 +16,7 @@ func TestFull(t *testing.T) {
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32
var dummy int
err := s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy)
err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, &dummy)
if err != nil {
t.FailNow()
}
......@@ -25,7 +25,7 @@ func TestFull(t *testing.T) {
p1.Name = "param_b"
p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
p1.ElementType = pserver.Float32
err = s.InitParam(pserver.ParameterWithConfig{p1, nil}, &dummy)
err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: nil}, &dummy)
if err != nil {
t.FailNow()
}
......@@ -81,7 +81,7 @@ func TestMultipleInit(t *testing.T) {
}
err = s.FinishInitParams(0, &dummy)
if err != pserver.ErrAlreadyInitialized {
if err.Error() != pserver.AlreadyInitialized {
t.FailNow()
}
}
......@@ -90,7 +90,7 @@ func TestUninitialized(t *testing.T) {
s := pserver.NewService()
var dummy int
err := s.SendGrad(pserver.Gradient{}, &dummy)
if err != pserver.ErrUninitialized {
if err.Error() != pserver.Uninitialized {
t.FailNow()
}
}
......@@ -135,7 +135,7 @@ func TestBlockUntilInitialized(t *testing.T) {
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32
var dummy int
err := s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy)
err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, &dummy)
if err != nil {
t.FailNow()
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册