提交 c44f5dd8 编写于 作者: Q qiaolongfei

add simple updater, this version can train a model

上级 966bf9ae
...@@ -42,7 +42,6 @@ import ( ...@@ -42,7 +42,6 @@ import (
"strings" "strings"
"sync" "sync"
"unsafe" "unsafe"
"fmt"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
) )
......
...@@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.0) ...@@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.0)
add_executable(main main.c) add_executable(main main.c)
add_dependencies(main paddle_pserver_cclient) add_dependencies(main paddle_pserver_cclient)
add_executable(test_cclient test_cclient.c)
add_dependencies(test_cclient paddle_pserver_cclient)
if(APPLE) if(APPLE)
set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security") set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security")
...@@ -10,7 +12,9 @@ endif() ...@@ -10,7 +12,9 @@ endif()
if(PROJ_ROOT) if(PROJ_ROOT)
include_directories(${CMAKE_BINARY_DIR}/go/pserver/cclient/) include_directories(${CMAKE_BINARY_DIR}/go/pserver/cclient/)
target_link_libraries(main ${CMAKE_BINARY_DIR}/go/pserver/cclient/libpaddle_pserver_cclient.a pthread) target_link_libraries(main ${CMAKE_BINARY_DIR}/go/pserver/cclient/libpaddle_pserver_cclient.a pthread)
target_link_libraries(test_cclient ${CMAKE_BINARY_DIR}/go/pserver/cclient/libpaddle_pserver_cclient.a pthread)
else(PROJ_ROOT) else(PROJ_ROOT)
include_directories(${CMAKE_BINARY_DIR}) include_directories(${CMAKE_BINARY_DIR})
target_link_libraries(main ${CMAKE_BINARY_DIR}/libpaddle_pserver_cclient.a pthread) target_link_libraries(main ${CMAKE_BINARY_DIR}/libpaddle_pserver_cclient.a pthread)
target_link_libraries(test_cclient ${CMAKE_BINARY_DIR}/libpaddle_pserver_cclient.a pthread)
endif(PROJ_ROOT) endif(PROJ_ROOT)
#include <stdio.h> #include <stdio.h>
#include <stdlib.h>
#include "libpaddle_pserver_cclient.h" #include "libpaddle_pserver_cclient.h"
...@@ -10,21 +9,6 @@ void fail() { ...@@ -10,21 +9,6 @@ void fail() {
exit(-1); exit(-1);
} }
void print_parameter(paddle_gradient* param) {
if (param == NULL) {
printf("param is NULL!!\n");
} else {
printf("==== parameter ====\n");
printf("name: %s\n", param->name);
printf("content_len: %d\n", param->content_len);
printf("content_type: %d\n", param->element_type);
for (int i = 0; i < param->content_len; ++i) {
printf("0x%x ", param->content[i]);
}
printf("\n\n");
}
}
int main() { int main() {
char addr[] = "localhost:3000"; char addr[] = "localhost:3000";
client c = paddle_new_pserver_client(addr, 1); client c = paddle_new_pserver_client(addr, 1);
...@@ -33,23 +17,21 @@ retry: ...@@ -33,23 +17,21 @@ retry:
paddle_parameter param; paddle_parameter param;
char name_a[] = "param_a"; char name_a[] = "param_a";
char name_b[] = "param_b"; char name_b[] = "param_b";
unsigned char content1[] = {0x01, 0x02, 0x03}; unsigned char content[] = {0x00, 0x11, 0x22};
param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32; param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param.name = name_a; param.name = name_a;
param.content = content1; param.content = content;
param.content_len = 3; param.content_len = 3;
if (paddle_init_param(c, param, NULL, 0) != 0) { if (paddle_init_param(c, param, NULL, 0) != 0) {
goto retry; goto retry;
} }
unsigned char content2[] = {0x04, 0x05, 0x06};
param.element_type = PADDLE_ELEMENT_TYPE_INT32; param.element_type = PADDLE_ELEMENT_TYPE_INT32;
param.name = name_b; param.name = name_b;
param.content = content2; param.content = content;
param.content_len = 3; param.content_len = 3;
if (paddle_init_param(c, param, NULL, 0) != 0) { if (paddle_init_param(c, param, NULL, 0) != 0) {
goto retry; goto retry;
} }
if (paddle_finish_init_params(c) != 0) { if (paddle_finish_init_params(c) != 0) {
goto retry; goto retry;
} }
...@@ -57,27 +39,22 @@ retry: ...@@ -57,27 +39,22 @@ retry:
fail(); fail();
} }
unsigned char content1[] = {0x12, 0x23, 0x34}; unsigned char content[] = {0x00, 0x11, 0x22};
unsigned char content2[] = {0x45, 0x56, 0x67}; paddle_gradient** grads =
paddle_gradient** new_params =
(paddle_gradient**)malloc(sizeof(paddle_gradient*) * 2); (paddle_gradient**)malloc(sizeof(paddle_gradient*) * 2);
new_params[0] = (paddle_gradient*)malloc(sizeof(paddle_gradient)); grads[0] = (paddle_gradient*)malloc(sizeof(paddle_gradient));
new_params[0]->name = "param_a"; grads[0]->name = "param_a";
new_params[0]->content = content1; grads[0]->content = content;
new_params[0]->content_len = 3; grads[0]->content_len = 3;
new_params[0]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32; grads[0]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
new_params[1] = (paddle_gradient*)malloc(sizeof(paddle_gradient)); grads[1] = (paddle_gradient*)malloc(sizeof(paddle_gradient));
new_params[1]->name = "param_b"; grads[1]->name = "param_b";
new_params[1]->content = content2; grads[1]->content = content;
new_params[1]->content_len = 3; grads[1]->content_len = 3;
new_params[1]->element_type = PADDLE_ELEMENT_TYPE_INT32; grads[1]->element_type = PADDLE_ELEMENT_TYPE_INT32;
print_parameter(new_params[0]); if (paddle_send_grads(c, grads, 2) != 0) {
print_parameter(new_params[1]);
if (paddle_send_grads(c, new_params, 2) != 0) {
fail(); fail();
} }
...@@ -87,15 +64,6 @@ retry: ...@@ -87,15 +64,6 @@ retry:
fail(); fail();
} }
print_parameter(params[0]);
print_parameter(params[1]);
/// change name of parameter.
char* names2[] = {"param_1", "param_2"};
if (paddle_get_params(c, names2, params, 2) == 0) {
fail();
}
// get parameters again by reusing the allocated parameter buffers. // get parameters again by reusing the allocated parameter buffers.
if (paddle_get_params(c, names, params, 2) != 0) { if (paddle_get_params(c, names, params, 2) != 0) {
fail(); fail();
...@@ -109,5 +77,6 @@ retry: ...@@ -109,5 +77,6 @@ retry:
} }
printf("test success!\n"); printf("test success!\n");
return 0; return 0;
} }
#include <stdio.h>
#include <stdlib.h>
#include "libpaddle_pserver_cclient.h"
typedef float real;
void fail() {
// TODO(helin): fix: gtest using cmake is not working, using this
// hacky way for now.
printf("test failed.\n");
exit(-1);
}
void print_parameter(paddle_gradient* param) {
if (param == NULL) {
printf("param is NULL!!\n");
} else {
printf("==== parameter ====\n");
printf("name: %s\n", param->name);
printf("content_len: %d\n", param->content_len);
printf("content_type: %d\n", param->element_type);
for (int i = 0; i < param->content_len/sizeof(real); ++i) {
printf("%f ", ((float *)param->content)[i]);
}
printf("\n\n");
}
}
int main() {
char addr[] = "localhost:3000";
client c = paddle_new_pserver_client(addr, 1);
char* names[] = {"param_a", "param_b"};
retry:
if (paddle_begin_init_params(c)) {
paddle_parameter param;
real param_content1[] = {0.1, 0.2, 0.3};
param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
param.name = names[0];
param.content = (unsigned char*)param_content1;
param.content_len = 3 * sizeof(real);
if (paddle_init_param(c, param, NULL, 0) != 0) {
goto retry;
}
real param_content2[] = {0.4, 0.5, 0.6};
param.element_type = PADDLE_ELEMENT_TYPE_INT32;
param.name = names[1];
param.content = (unsigned char*)param_content2;
param.content_len = 3 * sizeof(real);
if (paddle_init_param(c, param, NULL, 0) != 0) {
goto retry;
}
if (paddle_finish_init_params(c) != 0) {
goto retry;
}
} else {
fail();
}
printf("get initialized parameters from pserver:\n");
paddle_parameter* param_ptrs[2] = {NULL, NULL};
if (paddle_get_params(c, names, param_ptrs, 2) != 0) {
fail();
}
print_parameter(param_ptrs[0]);
print_parameter(param_ptrs[1]);
printf("send gradient to pserver:\n");
real gradient_content1[] = {0.01, 0.02, 0.03};
real gradinet_content2[] = {0.04, 0.05, 0.06};
paddle_gradient** grads =
(paddle_gradient**)malloc(sizeof(paddle_gradient*) * 2);
grads[0] = (paddle_gradient*)malloc(sizeof(paddle_gradient));
grads[0]->name = names[0];
grads[0]->content = (unsigned char*)gradient_content1;
grads[0]->content_len = 3 * sizeof(real);
grads[0]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
grads[1] = (paddle_gradient*)malloc(sizeof(paddle_gradient));
grads[1]->name = names[1];
grads[1]->content = (unsigned char*)gradinet_content2;
grads[1]->content_len = 3 * sizeof(real);
grads[1]->element_type = PADDLE_ELEMENT_TYPE_INT32;
print_parameter(grads[0]);
print_parameter(grads[1]);
if (paddle_send_grads(c, grads, 2) != 0) {
fail();
}
printf("get updated parameters from pserver:\n");
// get parameters again by reusing the allocated parameter buffers.
if (paddle_get_params(c, names, param_ptrs, 2) != 0) {
fail();
}
print_parameter(param_ptrs[0]);
print_parameter(param_ptrs[1]);
paddle_release_param(param_ptrs[0]);
paddle_release_param(param_ptrs[1]);
if (paddle_save_model(c, "/tmp/") != 0) {
fail();
}
printf("test success!\n");
return 0;
}
...@@ -32,7 +32,13 @@ int update_SGD(void* optimizer, ...@@ -32,7 +32,13 @@ int update_SGD(void* optimizer,
const void* gradient, const void* gradient,
int num_bytes) { int num_bytes) {
SGD_optimizer* o = (SGD_optimizer*)optimizer; SGD_optimizer* o = (SGD_optimizer*)optimizer;
// TODO // TODO(a simple SGD implement)
float* parameter = (float *)buffer;
float* grad = (float *)gradient;
for(int i = 0; i < num_bytes/sizeof(float); ++i) {
parameter[i] -= o->learning_rate * grad[i];
}
return 0; return 0;
} }
......
...@@ -73,8 +73,6 @@ void NewRemoteParameterUpdater::init( ...@@ -73,8 +73,6 @@ void NewRemoteParameterUpdater::init(
void NewRemoteParameterUpdater::updateImpl(Parameter *para) {} void NewRemoteParameterUpdater::updateImpl(Parameter *para) {}
void NewRemoteParameterUpdater::finishBatch(real cost) { void NewRemoteParameterUpdater::finishBatch(real cost) {
LOG(INFO) << "finishBatch in, cost: " << cost;
// send gradient to parameter server. // send gradient to parameter server.
paddle_send_grads(parameterClient_, newGradients_, parameterSize()); paddle_send_grads(parameterClient_, newGradients_, parameterSize());
// get the updated parameter from parameterClient. // get the updated parameter from parameterClient.
......
...@@ -87,7 +87,7 @@ private: ...@@ -87,7 +87,7 @@ private:
new_params[i]->name = (char*)param->getName().c_str(); new_params[i]->name = (char*)param->getName().c_str();
new_params[i]->content = new_params[i]->content =
(unsigned char*)(param->getBuf(type).get()->getData()); (unsigned char*)(param->getBuf(type).get()->getData());
new_params[i]->content_len = (int)param->getBuf(type).get()->getSize(); new_params[i]->content_len = (int)param->getBuf(type).get()->getSize() * sizeof(real);
} }
return new_params; return new_params;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册