提交 1b51d8de 编写于 作者: R rensilin

fix

Change-Id: I07b6ad389aa01a759500f0d580db8e187c4bc368
上级 540c5dc0
......@@ -78,43 +78,6 @@ int32_t DenseInputAccessor::forward(SampleInstance* samples, size_t num,
return 0;
}
int32_t DenseInputAccessor::backward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope) {
if (!_need_gradient) {
return 0;
}
size_t data_buffer_idx = 0;
std::vector<paddle::ps::Region> regions;
for (auto& variable : _x_variables) {
auto* tensor = scope->Var(variable.gradient_name)->
GetMutable<paddle::framework::LoDTensor>();
auto* grad_data = tensor->mutable_data<float>(_trainer_context->cpu_place);
regions.emplace_back(grad_data, variable.dim);
}
auto* ps_client = _trainer_context->pslib->ps_client();
auto push_status = ps_client->push_dense(regions.data(), regions.size(), _table_id);
//push_status.get();
if (!FLAGS_feed_trainer_debug_dense_name.empty()) {
std::stringstream ssm;
for (auto& variable : _x_variables) {
ssm.str("");
if (variable.name != FLAGS_feed_trainer_debug_dense_name) {
continue;
}
auto& tensor = scope->Var(variable.gradient_name)->
Get<paddle::framework::LoDTensor>();
const auto* var_data = tensor.data<float>();
for (size_t data_idx = 0; data_idx < variable.dim; ++data_idx) {
if (data_idx > 0)
ssm << ",";
ssm << var_data[data_idx];
}
VLOG(2) << "[DEBUG]push_dense: " << ssm.str();
}
}
return 0;
}
int32_t DenseInputAccessor::collect_persistables(paddle::framework::Scope* scope) {
// 首次同步pull,之后异步pull
if (_data_buffer == nullptr) {
......@@ -178,6 +141,43 @@ int32_t DenseInputAccessor::collect_persistables_name(std::vector<std::string>&
return 0;
}
int32_t DenseInputAccessor::backward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope) {
if (!_need_gradient) {
return 0;
}
size_t data_buffer_idx = 0;
std::vector<paddle::ps::Region> regions;
for (auto& variable : _x_variables) {
auto* tensor = scope->Var(variable.gradient_name)->
GetMutable<paddle::framework::LoDTensor>();
auto* grad_data = tensor->mutable_data<float>(_trainer_context->cpu_place);
regions.emplace_back(grad_data, variable.dim);
}
auto* ps_client = _trainer_context->pslib->ps_client();
auto push_status = ps_client->push_dense(regions.data(), regions.size(), _table_id);
//push_status.get();
if (!FLAGS_feed_trainer_debug_dense_name.empty()) {
std::stringstream ssm;
for (auto& variable : _x_variables) {
ssm.str("");
if (variable.name != FLAGS_feed_trainer_debug_dense_name) {
continue;
}
auto& tensor = scope->Var(variable.gradient_name)->
Get<paddle::framework::LoDTensor>();
const auto* var_data = tensor.data<float>();
for (size_t data_idx = 0; data_idx < variable.dim; ++data_idx) {
if (data_idx > 0)
ssm << ",";
ssm << var_data[data_idx];
}
VLOG(2) << "[DEBUG]push_dense: " << ssm.str();
}
}
return 0;
}
int32_t EbdVariableInputAccessor::forward(SampleInstance* samples, size_t num,
paddle::framework::Scope* scope) {
CHECK(_x_variables.size() == 1);
......
......@@ -2,6 +2,8 @@
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/monitor/monitor.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace custom_trainer {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册