提交 28476f5f 编写于 作者: Q qiaolongfei

fix the problem of paddle_send_grad

上级 6f1c91da
...@@ -164,10 +164,10 @@ func paddle_finish_init_params(client C.client) C.int { ...@@ -164,10 +164,10 @@ func paddle_finish_init_params(client C.client) C.int {
} }
//export paddle_send_grads //export paddle_send_grads
func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C.int { func paddle_send_grads(client C.client, grads **C.paddle_gradient, total C.int) C.int {
var gs []pserver.Gradient var gs []pserver.Gradient
for i := 0; i < int(total); i++ { for i := 0; i < int(total); i++ {
grad := (*C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads)))) grad := *(**C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads))))
et := pserver.ElementType(grad.element_type) et := pserver.ElementType(grad.element_type)
name := C.GoString(grad.name) name := C.GoString(grad.name)
content := cArrayToSlice(unsafe.Pointer(grad.content), int(grad.content_len)) content := cArrayToSlice(unsafe.Pointer(grad.content), int(grad.content_len))
......
#include <stdio.h> #include <stdio.h>
#include <stdlib.h>
#include "libpaddle_pserver_cclient.h" #include "libpaddle_pserver_cclient.h"
...@@ -9,6 +10,21 @@ void fail() { ...@@ -9,6 +10,21 @@ void fail() {
exit(-1); exit(-1);
} }
void print_parameter(paddle_gradient* param) {
if (param == NULL) {
printf("param is NULL!!\n");
} else {
printf("==== parameter ====\n");
printf("name: %s\n", param->name);
printf("content_len: %d\n", param->content_len);
printf("content_type: %d\n", param->element_type);
for (int i = 0; i < param->content_len; ++i) {
printf("0x%x ", param->content[i]);
}
printf("\n");
}
}
int main() { int main() {
char addr[] = "localhost:3000"; char addr[] = "localhost:3000";
client c = paddle_new_pserver_client(addr, 1); client c = paddle_new_pserver_client(addr, 1);
...@@ -40,12 +56,27 @@ retry: ...@@ -40,12 +56,27 @@ retry:
fail(); fail();
} }
unsigned char content[] = {0x00, 0x11, 0x22}; unsigned char content1[] = {0x12, 0x23, 0x34};
paddle_gradient grads[2] = { unsigned char content2[] = {0x45, 0x56, 0x67};
{"param_a", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3},
{"param_b", PADDLE_ELEMENT_TYPE_INT32, content, 3}}; paddle_gradient** new_params =
(paddle_gradient**)malloc(sizeof(paddle_gradient*) * 2);
new_params[0] = (paddle_gradient*)malloc(sizeof(paddle_gradient));
new_params[0]->name = "param_a";
new_params[0]->content = content1;
new_params[0]->content_len = 3;
new_params[0]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
if (paddle_send_grads(c, grads, 2) != 0) { new_params[1] = (paddle_gradient*)malloc(sizeof(paddle_gradient));
new_params[1]->name = "param_b";
new_params[1]->content = content2;
new_params[1]->content_len = 3;
new_params[1]->element_type = PADDLE_ELEMENT_TYPE_INT32;
print_parameter(new_params[0]);
print_parameter(new_params[1]);
if (paddle_send_grads(c, new_params, 2) != 0) {
fail(); fail();
} }
...@@ -55,6 +86,15 @@ retry: ...@@ -55,6 +86,15 @@ retry:
fail(); fail();
} }
print_parameter(params[0]);
print_parameter(params[1]);
/// change name of parameter.
char* names2[] = {"param_1", "param_2"};
if (paddle_get_params(c, names2, params, 2) == 0) {
fail();
}
// get parameters again by reusing the allocated parameter buffers. // get parameters again by reusing the allocated parameter buffers.
if (paddle_get_params(c, names, params, 2) != 0) { if (paddle_get_params(c, names, params, 2) != 0) {
fail(); fail();
......
...@@ -22,7 +22,11 @@ DECLARE_string(save_dir); ...@@ -22,7 +22,11 @@ DECLARE_string(save_dir);
namespace paddle { namespace paddle {
NewRemoteParameterUpdater::NewRemoteParameterUpdater( NewRemoteParameterUpdater::NewRemoteParameterUpdater(
const OptimizationConfig &config, const std::string pserverSpec) const OptimizationConfig &config, const std::string pserverSpec)
: pserverSpec_(pserverSpec) {} : parameterClient_(-1),
newParameters_(nullptr),
newGradients_(nullptr),
names_(nullptr),
pserverSpec_(pserverSpec) {}
void NewRemoteParameterUpdater::init( void NewRemoteParameterUpdater::init(
const std::vector<ParameterPtr> &parameters) { const std::vector<ParameterPtr> &parameters) {
...@@ -72,7 +76,7 @@ void NewRemoteParameterUpdater::finishBatch(real cost) { ...@@ -72,7 +76,7 @@ void NewRemoteParameterUpdater::finishBatch(real cost) {
LOG(INFO) << "finishBatch in, cost: " << cost; LOG(INFO) << "finishBatch in, cost: " << cost;
// send gradient to parameter server. // send gradient to parameter server.
paddle_send_grads(parameterClient_, *newGradients_, parameterSize()); paddle_send_grads(parameterClient_, newGradients_, parameterSize());
// get the updated parameter from parameterClient. // get the updated parameter from parameterClient.
paddle_get_params(parameterClient_, names_, newParameters_, parameterSize()); paddle_get_params(parameterClient_, names_, newParameters_, parameterSize());
......
...@@ -32,6 +32,7 @@ public: ...@@ -32,6 +32,7 @@ public:
NewRemoteParameterUpdater(const OptimizationConfig& config, NewRemoteParameterUpdater(const OptimizationConfig& config,
const std::string pserverSpec); const std::string pserverSpec);
~NewRemoteParameterUpdater() { ~NewRemoteParameterUpdater() {
LOG(INFO) << "~NewRemoteParameterUpdater in";
releaseNewParameter(newParameters_); releaseNewParameter(newParameters_);
releaseNewParameter(newGradients_); releaseNewParameter(newGradients_);
if (parameterClient_ >= 0) paddle_pserver_client_release(parameterClient_); if (parameterClient_ >= 0) paddle_pserver_client_release(parameterClient_);
...@@ -64,46 +65,47 @@ protected: ...@@ -64,46 +65,47 @@ protected:
virtual void updateImpl(Parameter* para); virtual void updateImpl(Parameter* para);
private: private:
int parameterSize() { int parameterSize() { return (int)parameters_.size(); }
return (int)parameters_.size();
}
/** /**
* init parameter of paddle pserver cclient. * init parameter of go paddle pserver cclient.
* @param new_params * @param new_params
* @param type * @param type
*/ */
paddle_parameter** initNewParameter(ParameterType type) { paddle_parameter** initNewParameter(ParameterType type) {
paddle_parameter** new_params = paddle_parameter** new_params =
(paddle_parameter**)malloc(sizeof(paddle_parameter*) * parameterSize()); (paddle_parameter**)malloc(sizeof(paddle_parameter*) * parameterSize());
for (int i = 0; i < parameterSize(); ++i) { for (int i = 0; i < parameterSize(); ++i) {
new_params[i] = (paddle_parameter*)malloc(sizeof(paddle_parameter)); new_params[i] = (paddle_parameter*)malloc(sizeof(paddle_parameter));
memset(new_params[i], 0, sizeof(paddle_parameter)); memset(new_params[i], 0, sizeof(paddle_parameter));
} }
for (int i = 0; i < parameterSize(); ++i) { for (int i = 0; i < parameterSize(); ++i) {
ParameterPtr param = parameters_[i]; ParameterPtr param = parameters_[i];
new_params[i]->content_len = 10; new_params[i]->content_len = 10;
new_params[i]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32; new_params[i]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
new_params[i]->name = (char*)param->getName().c_str(); new_params[i]->name = (char*)param->getName().c_str();
new_params[i]->content = new_params[i]->content =
(unsigned char*)(param->getBuf(type).get()->getData()); (unsigned char*)(param->getBuf(type).get()->getData());
new_params[i]->content_len = (int)param->getBuf(type).get()->getSize(); new_params[i]->content_len = (int)param->getBuf(type).get()->getSize();
}
return new_params;
} }
return new_params;
}
void releaseNewParameter(paddle_parameter** newParams) { void releaseNewParameter(paddle_parameter** newParams) {
if (newParams != NULL) { if (newParams != nullptr) {
for (int i = 0; i < parameterSize(); ++i) { for (int i = 0; i < parameterSize(); ++i) {
paddle_release_param(newParams[i]); auto param = newParams[i];
if (param != nullptr) {
paddle_release_param(param);
} }
} }
} }
}
protected: protected:
/// internal parameter client object for exchanging data with pserver /// internal parameter client object for exchanging data with pserver
client parameterClient_ = -1; client parameterClient_;
/// the parameters for new pserver client /// the parameters for new pserver client
paddle_parameter** newParameters_; paddle_parameter** newParameters_;
/// the gradinets for new pserver client /// the gradinets for new pserver client
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册