未验证 提交 f3e4e42d 编写于 作者: Y Yang yaming 提交者: GitHub

Merge pull request #10130 from reyoung/feature/skip_loss

Add customize_loss_grad option to PE
...@@ -34,7 +34,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( ...@@ -34,7 +34,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::string &loss_var_name,
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes, bool skip_scale_loss,
platform::NCCLContextMap *nccl_ctxs) platform::NCCLContextMap *nccl_ctxs)
: loss_var_name_(loss_var_name), : loss_var_name_(loss_var_name),
places_(places), places_(places),
...@@ -45,7 +45,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( ...@@ -45,7 +45,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::string &loss_var_name,
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes) const std::vector<Scope *> &local_scopes, bool skip_scale_loss)
: loss_var_name_(loss_var_name), : loss_var_name_(loss_var_name),
places_(places), places_(places),
local_scopes_(local_scopes) { local_scopes_(local_scopes) {
...@@ -53,6 +53,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( ...@@ -53,6 +53,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
for (auto &p : params) { for (auto &p : params) {
grad_names_.insert(GradVarName(p)); grad_names_.insert(GradVarName(p));
} }
skip_scale_loss_ = skip_scale_loss;
} }
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
...@@ -95,7 +96,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -95,7 +96,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
// always use the first device // always use the first device
CreateSendOp(&result, *op); CreateSendOp(&result, *op);
} else if (IsScaleLossOp(*op)) { } else if (IsScaleLossOp(*op)) {
if (!skip_scale_loss_) {
CreateScaleLossGradOp(&result); CreateScaleLossGradOp(&result);
}
is_forwarding = false; is_forwarding = false;
} else { } else {
CreateComputationalOps(&result, *op); CreateComputationalOps(&result, *op);
......
...@@ -34,12 +34,14 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -34,12 +34,14 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::string &loss_var_name, const std::string &loss_var_name,
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
bool skip_scale_loss,
platform::NCCLContextMap *nccl_ctxs); platform::NCCLContextMap *nccl_ctxs);
#else #else
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places, MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::string &loss_var_name,
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes); const std::vector<Scope *> &local_scopes,
bool skip_scale_loss);
#endif #endif
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override; std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
...@@ -57,6 +59,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -57,6 +59,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap *nccl_ctxs_; platform::NCCLContextMap *nccl_ctxs_;
#endif #endif
bool skip_scale_loss_;
bool IsScaleLossOp(const OpDesc &op) const; bool IsScaleLossOp(const OpDesc &op) const;
......
...@@ -57,7 +57,8 @@ ParallelExecutor::ParallelExecutor( ...@@ -57,7 +57,8 @@ ParallelExecutor::ParallelExecutor(
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const std::unordered_set<std::string> &bcast_vars, const std::unordered_set<std::string> &bcast_vars,
const ProgramDesc &main_program, const std::string &loss_var_name, const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay) Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay,
bool customize_scale_loss)
: member_(new ParallelExecutorPrivate(places)) { : member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope; member_->global_scope_ = scope;
...@@ -90,12 +91,13 @@ ParallelExecutor::ParallelExecutor( ...@@ -90,12 +91,13 @@ ParallelExecutor::ParallelExecutor(
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp // ncclOp
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name, details::MultiDevSSAGraphBuilder builder(
params, member_->local_scopes_, member_->places_, loss_var_name, params, member_->local_scopes_,
member_->nccl_ctxs_.get()); customize_scale_loss, member_->nccl_ctxs_.get());
#else #else
details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name, details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
params, member_->local_scopes_); params, member_->local_scopes_,
customize_scale_loss);
#endif #endif
auto graph = builder.Build(main_program); auto graph = builder.Build(main_program);
......
...@@ -40,7 +40,7 @@ class ParallelExecutor { ...@@ -40,7 +40,7 @@ class ParallelExecutor {
const ProgramDesc& main_program, const ProgramDesc& main_program,
const std::string& loss_var_name, Scope* scope, const std::string& loss_var_name, Scope* scope,
const std::vector<Scope*>& local_scopes, const std::vector<Scope*>& local_scopes,
bool allow_op_delay); bool allow_op_delay, bool customize_scale_loss);
~ParallelExecutor(); ~ParallelExecutor();
......
...@@ -502,11 +502,11 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -502,11 +502,11 @@ All parameter, weight, gradient are variables in Paddle.
const std::unordered_set<std::string> &bcast_vars, const std::unordered_set<std::string> &bcast_vars,
const ProgramDesc &main_program, const std::string &loss_var_name, const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, std::vector<Scope *> &local_scopes, Scope *scope, std::vector<Scope *> &local_scopes,
bool allow_op_delay) { bool allow_op_delay, bool customize_loss_grad) {
new (&self) new (&self) ParallelExecutor(num_threads, use_event, places,
ParallelExecutor(num_threads, use_event, places, params, params, bcast_vars, main_program,
bcast_vars, main_program, loss_var_name, loss_var_name, scope, local_scopes,
scope, local_scopes, allow_op_delay); allow_op_delay, customize_loss_grad);
}) })
.def("bcast_params", &ParallelExecutor::BCastParamsToGPUs) .def("bcast_params", &ParallelExecutor::BCastParamsToGPUs)
// NOTE: even we return a vec<Scope*>* to Python use reference policy. // NOTE: even we return a vec<Scope*>* to Python use reference policy.
......
...@@ -29,7 +29,8 @@ class ParallelExecutor(object): ...@@ -29,7 +29,8 @@ class ParallelExecutor(object):
main_program=None, main_program=None,
num_threads=None, num_threads=None,
allow_op_delay=False, allow_op_delay=False,
share_vars_from=None): share_vars_from=None,
customize_loss_grad=False):
""" """
ParallelExecutor can run program in parallel. ParallelExecutor can run program in parallel.
...@@ -122,7 +123,8 @@ class ParallelExecutor(object): ...@@ -122,7 +123,8 @@ class ParallelExecutor(object):
loss_name if loss_name else '', loss_name if loss_name else '',
scope, scope,
local_scopes, local_scopes,
allow_op_delay) allow_op_delay,
customize_loss_grad)
self.scope = scope self.scope = scope
def run(self, fetch_list, feed=None, feed_dict=None): def run(self, fetch_list, feed=None, feed_dict=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册