diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 4755559f8d0c5b5fdeb6b56a28fff8a32ea7f82f..5473aa5b4673eeb99f46e0c909b38f590bf31c53 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -145,12 +145,14 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( } else if (IsDistTrainOp(*op, send_op)) { CreateComputationalOps(&result, *op, 1); } else if (IsScaleLossOp(*op)) { + CreateComputationalOps(&result, *op, places_.size()); // user can customize loss@grad if not use_default_grad_scale_ if (use_default_grad_scale_) { CreateScaleLossGradOp(&result); } is_forwarding = false; } else { + if (IsScaleLossGradOp(*op)) continue; int op_dev_id = GetOpDeviceID(var_name_on_devices, *op); if (op_dev_id == -1) { // var on all device CreateComputationalOps(&result, *op, places_.size()); @@ -399,6 +401,12 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result, } bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { + // FIXME(yy): Do not hard code like this + return op.OutputArgumentNames().size() == 1 && + (op.OutputArgumentNames()[0]) == loss_var_name_; +} + +bool MultiDevSSAGraphBuilder::IsScaleLossGradOp(const OpDesc &op) const { // FIXME(yy): Do not hard code like this return op.OutputArgumentNames().size() == 1 && op.OutputArgumentNames()[0] == GradVarName(loss_var_name_); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 3a3e9e3b8538f52962e6a5ccd1a177e58d6c2f6b..8a59079ac309301793861888a479d45b203d3eed 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -67,6 +67,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { bool IsScaleLossOp(const OpDesc &op) const; + bool IsScaleLossGradOp(const OpDesc &op) const; + void CreateSendOp(SSAGraph *result, const OpDesc &op) const; /**