From 1b51d8de7f3710602beb03bbf2562110dae178f3 Mon Sep 17 00:00:00 2001 From: rensilin Date: Wed, 4 Sep 2019 16:50:15 +0800 Subject: [PATCH] fix Change-Id: I07b6ad389aa01a759500f0d580db8e187c4bc368 --- .../feed/accessor/dense_input_accessor.cc | 74 +++++++++---------- .../feed/executor/multi_thread_executor.cc | 2 + 2 files changed, 39 insertions(+), 37 deletions(-) diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/dense_input_accessor.cc b/paddle/fluid/train/custom_trainer/feed/accessor/dense_input_accessor.cc index 1957b695..29b0724a 100644 --- a/paddle/fluid/train/custom_trainer/feed/accessor/dense_input_accessor.cc +++ b/paddle/fluid/train/custom_trainer/feed/accessor/dense_input_accessor.cc @@ -78,43 +78,6 @@ int32_t DenseInputAccessor::forward(SampleInstance* samples, size_t num, return 0; } -int32_t DenseInputAccessor::backward(SampleInstance* samples, size_t num, - paddle::framework::Scope* scope) { - if (!_need_gradient) { - return 0; - } - size_t data_buffer_idx = 0; - std::vector regions; - for (auto& variable : _x_variables) { - auto* tensor = scope->Var(variable.gradient_name)-> - GetMutable(); - auto* grad_data = tensor->mutable_data(_trainer_context->cpu_place); - regions.emplace_back(grad_data, variable.dim); - } - auto* ps_client = _trainer_context->pslib->ps_client(); - auto push_status = ps_client->push_dense(regions.data(), regions.size(), _table_id); - //push_status.get(); - if (!FLAGS_feed_trainer_debug_dense_name.empty()) { - std::stringstream ssm; - for (auto& variable : _x_variables) { - ssm.str(""); - if (variable.name != FLAGS_feed_trainer_debug_dense_name) { - continue; - } - auto& tensor = scope->Var(variable.gradient_name)-> - Get(); - const auto* var_data = tensor.data(); - for (size_t data_idx = 0; data_idx < variable.dim; ++data_idx) { - if (data_idx > 0) - ssm << ","; - ssm << var_data[data_idx]; - } - VLOG(2) << "[DEBUG]push_dense: " << ssm.str(); - } - } - return 0; -} - int32_t DenseInputAccessor::collect_persistables(paddle::framework::Scope* scope) { // 首次同步pull,之后异步pull if (_data_buffer == nullptr) { @@ -178,6 +141,43 @@ int32_t DenseInputAccessor::collect_persistables_name(std::vector& return 0; } +int32_t DenseInputAccessor::backward(SampleInstance* samples, size_t num, + paddle::framework::Scope* scope) { + if (!_need_gradient) { + return 0; + } + size_t data_buffer_idx = 0; + std::vector regions; + for (auto& variable : _x_variables) { + auto* tensor = scope->Var(variable.gradient_name)-> + GetMutable(); + auto* grad_data = tensor->mutable_data(_trainer_context->cpu_place); + regions.emplace_back(grad_data, variable.dim); + } + auto* ps_client = _trainer_context->pslib->ps_client(); + auto push_status = ps_client->push_dense(regions.data(), regions.size(), _table_id); + //push_status.get(); + if (!FLAGS_feed_trainer_debug_dense_name.empty()) { + std::stringstream ssm; + for (auto& variable : _x_variables) { + ssm.str(""); + if (variable.name != FLAGS_feed_trainer_debug_dense_name) { + continue; + } + auto& tensor = scope->Var(variable.gradient_name)-> + Get(); + const auto* var_data = tensor.data(); + for (size_t data_idx = 0; data_idx < variable.dim; ++data_idx) { + if (data_idx > 0) + ssm << ","; + ssm << var_data[data_idx]; + } + VLOG(2) << "[DEBUG]push_dense: " << ssm.str(); + } + } + return 0; +} + int32_t EbdVariableInputAccessor::forward(SampleInstance* samples, size_t num, paddle::framework::Scope* scope) { CHECK(_x_variables.size() == 1); diff --git a/paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc b/paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc index e7ef07b0..3e027589 100644 --- a/paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc +++ b/paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.cc @@ -2,6 +2,8 @@ #include "paddle/fluid/train/custom_trainer/feed/io/file_system.h" #include "paddle/fluid/train/custom_trainer/feed/monitor/monitor.h" #include "paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.h" +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/program_desc.h" namespace paddle { namespace custom_trainer { -- GitLab