提交 28476f5f 编写于 作者: Q qiaolongfei

fix the problem of paddle_send_grad

上级 6f1c91da
......@@ -164,10 +164,10 @@ func paddle_finish_init_params(client C.client) C.int {
}
//export paddle_send_grads
func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C.int {
func paddle_send_grads(client C.client, grads **C.paddle_gradient, total C.int) C.int {
var gs []pserver.Gradient
for i := 0; i < int(total); i++ {
grad := (*C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads))))
grad := *(**C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads))))
et := pserver.ElementType(grad.element_type)
name := C.GoString(grad.name)
content := cArrayToSlice(unsafe.Pointer(grad.content), int(grad.content_len))
......
#include <stdio.h>
#include <stdlib.h>
#include "libpaddle_pserver_cclient.h"
......@@ -9,6 +10,21 @@ void fail() {
exit(-1);
}
void print_parameter(paddle_gradient* param) {
if (param == NULL) {
printf("param is NULL!!\n");
} else {
printf("==== parameter ====\n");
printf("name: %s\n", param->name);
printf("content_len: %d\n", param->content_len);
printf("content_type: %d\n", param->element_type);
for (int i = 0; i < param->content_len; ++i) {
printf("0x%x ", param->content[i]);
}
printf("\n");
}
}
int main() {
char addr[] = "localhost:3000";
client c = paddle_new_pserver_client(addr, 1);
......@@ -40,12 +56,27 @@ retry:
fail();
}
unsigned char content[] = {0x00, 0x11, 0x22};
paddle_gradient grads[2] = {
{"param_a", PADDLE_ELEMENT_TYPE_FLOAT32, content, 3},
{"param_b", PADDLE_ELEMENT_TYPE_INT32, content, 3}};
unsigned char content1[] = {0x12, 0x23, 0x34};
unsigned char content2[] = {0x45, 0x56, 0x67};
paddle_gradient** new_params =
(paddle_gradient**)malloc(sizeof(paddle_gradient*) * 2);
new_params[0] = (paddle_gradient*)malloc(sizeof(paddle_gradient));
new_params[0]->name = "param_a";
new_params[0]->content = content1;
new_params[0]->content_len = 3;
new_params[0]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
if (paddle_send_grads(c, grads, 2) != 0) {
new_params[1] = (paddle_gradient*)malloc(sizeof(paddle_gradient));
new_params[1]->name = "param_b";
new_params[1]->content = content2;
new_params[1]->content_len = 3;
new_params[1]->element_type = PADDLE_ELEMENT_TYPE_INT32;
print_parameter(new_params[0]);
print_parameter(new_params[1]);
if (paddle_send_grads(c, new_params, 2) != 0) {
fail();
}
......@@ -55,6 +86,15 @@ retry:
fail();
}
print_parameter(params[0]);
print_parameter(params[1]);
/// change name of parameter.
char* names2[] = {"param_1", "param_2"};
if (paddle_get_params(c, names2, params, 2) == 0) {
fail();
}
// get parameters again by reusing the allocated parameter buffers.
if (paddle_get_params(c, names, params, 2) != 0) {
fail();
......
......@@ -22,7 +22,11 @@ DECLARE_string(save_dir);
namespace paddle {
NewRemoteParameterUpdater::NewRemoteParameterUpdater(
const OptimizationConfig &config, const std::string pserverSpec)
: pserverSpec_(pserverSpec) {}
: parameterClient_(-1),
newParameters_(nullptr),
newGradients_(nullptr),
names_(nullptr),
pserverSpec_(pserverSpec) {}
void NewRemoteParameterUpdater::init(
const std::vector<ParameterPtr> &parameters) {
......@@ -72,7 +76,7 @@ void NewRemoteParameterUpdater::finishBatch(real cost) {
LOG(INFO) << "finishBatch in, cost: " << cost;
// send gradient to parameter server.
paddle_send_grads(parameterClient_, *newGradients_, parameterSize());
paddle_send_grads(parameterClient_, newGradients_, parameterSize());
// get the updated parameter from parameterClient.
paddle_get_params(parameterClient_, names_, newParameters_, parameterSize());
......
......@@ -32,6 +32,7 @@ public:
NewRemoteParameterUpdater(const OptimizationConfig& config,
const std::string pserverSpec);
~NewRemoteParameterUpdater() {
LOG(INFO) << "~NewRemoteParameterUpdater in";
releaseNewParameter(newParameters_);
releaseNewParameter(newGradients_);
if (parameterClient_ >= 0) paddle_pserver_client_release(parameterClient_);
......@@ -64,46 +65,47 @@ protected:
virtual void updateImpl(Parameter* para);
private:
int parameterSize() {
return (int)parameters_.size();
}
int parameterSize() { return (int)parameters_.size(); }
/**
* init parameter of paddle pserver cclient.
* @param new_params
* @param type
*/
paddle_parameter** initNewParameter(ParameterType type) {
paddle_parameter** new_params =
(paddle_parameter**)malloc(sizeof(paddle_parameter*) * parameterSize());
for (int i = 0; i < parameterSize(); ++i) {
new_params[i] = (paddle_parameter*)malloc(sizeof(paddle_parameter));
memset(new_params[i], 0, sizeof(paddle_parameter));
}
/**
* init parameter of go paddle pserver cclient.
* @param new_params
* @param type
*/
paddle_parameter** initNewParameter(ParameterType type) {
paddle_parameter** new_params =
(paddle_parameter**)malloc(sizeof(paddle_parameter*) * parameterSize());
for (int i = 0; i < parameterSize(); ++i) {
new_params[i] = (paddle_parameter*)malloc(sizeof(paddle_parameter));
memset(new_params[i], 0, sizeof(paddle_parameter));
}
for (int i = 0; i < parameterSize(); ++i) {
ParameterPtr param = parameters_[i];
new_params[i]->content_len = 10;
new_params[i]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
new_params[i]->name = (char*)param->getName().c_str();
new_params[i]->content =
(unsigned char*)(param->getBuf(type).get()->getData());
new_params[i]->content_len = (int)param->getBuf(type).get()->getSize();
}
return new_params;
for (int i = 0; i < parameterSize(); ++i) {
ParameterPtr param = parameters_[i];
new_params[i]->content_len = 10;
new_params[i]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
new_params[i]->name = (char*)param->getName().c_str();
new_params[i]->content =
(unsigned char*)(param->getBuf(type).get()->getData());
new_params[i]->content_len = (int)param->getBuf(type).get()->getSize();
}
return new_params;
}
void releaseNewParameter(paddle_parameter** newParams) {
if (newParams != NULL) {
for (int i = 0; i < parameterSize(); ++i) {
paddle_release_param(newParams[i]);
void releaseNewParameter(paddle_parameter** newParams) {
if (newParams != nullptr) {
for (int i = 0; i < parameterSize(); ++i) {
auto param = newParams[i];
if (param != nullptr) {
paddle_release_param(param);
}
}
}
}
protected:
/// internal parameter client object for exchanging data with pserver
client parameterClient_ = -1;
client parameterClient_;
/// the parameters for new pserver client
paddle_parameter** newParameters_;
/// the gradinets for new pserver client
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册