From 8c6dae777684d078fe1064e44091e5d5e9493220 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 14 May 2018 20:40:17 +0800 Subject: [PATCH] fix pe bug --- .../framework/details/multi_devices_graph_builder.cc | 8 ++++++++ .../fluid/framework/details/multi_devices_graph_builder.h | 2 ++ 2 files changed, 10 insertions(+) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 4755559f8..5473aa5b4 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -145,12 +145,14 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( } else if (IsDistTrainOp(*op, send_op)) { CreateComputationalOps(&result, *op, 1); } else if (IsScaleLossOp(*op)) { + CreateComputationalOps(&result, *op, places_.size()); // user can customize loss@grad if not use_default_grad_scale_ if (use_default_grad_scale_) { CreateScaleLossGradOp(&result); } is_forwarding = false; } else { + if (IsScaleLossGradOp(*op)) continue; int op_dev_id = GetOpDeviceID(var_name_on_devices, *op); if (op_dev_id == -1) { // var on all device CreateComputationalOps(&result, *op, places_.size()); @@ -399,6 +401,12 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result, } bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { + // FIXME(yy): Do not hard code like this + return op.OutputArgumentNames().size() == 1 && + (op.OutputArgumentNames()[0]) == loss_var_name_; +} + +bool MultiDevSSAGraphBuilder::IsScaleLossGradOp(const OpDesc &op) const { // FIXME(yy): Do not hard code like this return op.OutputArgumentNames().size() == 1 && op.OutputArgumentNames()[0] == GradVarName(loss_var_name_); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 3a3e9e3b8..8a59079ac 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -67,6 +67,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { bool IsScaleLossOp(const OpDesc &op) const; + bool IsScaleLossGradOp(const OpDesc &op) const; + void CreateSendOp(SSAGraph *result, const OpDesc &op) const; /** -- GitLab