diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index c2eb1c31b4f5625e662436e278a33c55b38bb004..725dc57b047acfdc9d88a1502dcbc7be92679f3f 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -45,7 +45,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( const std::vector &places, const std::string &loss_var_name, const std::unordered_set ¶ms, - const std::vector &local_scopes, bool skip_scale_loss) + const std::vector &local_scopes, bool use_default_grad_scale) : loss_var_name_(loss_var_name), places_(places), local_scopes_(local_scopes) { @@ -53,7 +53,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( for (auto &p : params) { grad_names_.insert(GradVarName(p)); } - skip_scale_loss_ = skip_scale_loss; + use_default_grad_scale_ = use_default_grad_scale; } void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, @@ -126,8 +126,8 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( } else if (IsDistTrainOp(*op, send_op)) { CreateComputationalOps(&result, *op, 1); } else if (IsScaleLossOp(*op)) { - // user can customize loss@grad if skip_scale_loss_ - if (!skip_scale_loss_) { + // user can customize loss@grad if not use_default_grad_scale_ + if (use_default_grad_scale_) { CreateScaleLossGradOp(&result); } is_forwarding = false; diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index fa4d31bdc49da5d30340a710c950dcc8cd70180b..bad47458ef4cd1cd42e902341e8be66da5c210ed 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -41,7 +41,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { const std::string &loss_var_name, const std::unordered_set ¶ms, const std::vector &local_scopes, - bool skip_scale_loss); + bool use_default_grad_scale); #endif std::unique_ptr Build(const ProgramDesc &program) const override; @@ -59,7 +59,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { #ifdef PADDLE_WITH_CUDA platform::NCCLContextMap *nccl_ctxs_; #endif - bool skip_scale_loss_; + bool use_default_grad_scale_; bool IsScaleLossOp(const OpDesc &op) const;