From c0ac0cd6b37732690168e8fa311242fd25f095bd Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sat, 28 Apr 2018 14:07:00 +0800 Subject: [PATCH] Complete rename --- .../framework/details/multi_devices_graph_builder.cc | 8 ++++---- .../fluid/framework/details/multi_devices_graph_builder.h | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index c2eb1c31b4..725dc57b04 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -45,7 +45,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( const std::vector &places, const std::string &loss_var_name, const std::unordered_set ¶ms, - const std::vector &local_scopes, bool skip_scale_loss) + const std::vector &local_scopes, bool use_default_grad_scale) : loss_var_name_(loss_var_name), places_(places), local_scopes_(local_scopes) { @@ -53,7 +53,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( for (auto &p : params) { grad_names_.insert(GradVarName(p)); } - skip_scale_loss_ = skip_scale_loss; + use_default_grad_scale_ = use_default_grad_scale; } void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, @@ -126,8 +126,8 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( } else if (IsDistTrainOp(*op, send_op)) { CreateComputationalOps(&result, *op, 1); } else if (IsScaleLossOp(*op)) { - // user can customize loss@grad if skip_scale_loss_ - if (!skip_scale_loss_) { + // user can customize loss@grad if not use_default_grad_scale_ + if (use_default_grad_scale_) { CreateScaleLossGradOp(&result); } is_forwarding = false; diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index fa4d31bdc4..bad47458ef 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -41,7 +41,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { const std::string &loss_var_name, const std::unordered_set ¶ms, const std::vector &local_scopes, - bool skip_scale_loss); + bool use_default_grad_scale); #endif std::unique_ptr Build(const ProgramDesc &program) const override; @@ -59,7 +59,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { #ifdef PADDLE_WITH_CUDA platform::NCCLContextMap *nccl_ctxs_; #endif - bool skip_scale_loss_; + bool use_default_grad_scale_; bool IsScaleLossOp(const OpDesc &op) const; -- GitLab