From d085f79228389ca4984f66069349af8268edd8b8 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Wed, 8 Apr 2020 09:05:26 +0800 Subject: [PATCH] fix untime fail for output var stop_gradient=True for fusion group (#23317) --- paddle/fluid/framework/ir/fusion_group/code_generator.cc | 5 +---- .../ir/fusion_group/elementwise_group_detector.cc | 9 +++++++++ .../tests/unittests/ir/test_ir_fusion_group_pass.py | 6 ++---- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator.cc b/paddle/fluid/framework/ir/fusion_group/code_generator.cc index fd224de483..e63764ae91 100644 --- a/paddle/fluid/framework/ir/fusion_group/code_generator.cc +++ b/paddle/fluid/framework/ir/fusion_group/code_generator.cc @@ -99,6 +99,7 @@ std::vector CodeGenerator::ConvertToExpressions( input_ids.push_back(-1); } } + // Output ids should be set in fixed order, like: // - dx, dy in backward operations std::vector output_ids; @@ -106,10 +107,6 @@ std::vector CodeGenerator::ConvertToExpressions( OperationMap::Instance().Get(op->Type()).output_names; for (auto& name : output_names) { - PADDLE_ENFORCE_EQ( - op->Output(name).size(), 1U, - platform::errors::InvalidArgument( - "Output(%s) of operation %s is not set.", name, op->Type())); PADDLE_ENFORCE_NE( var_ids.find(op->Output(name)[0]), var_ids.end(), platform::errors::InvalidArgument( diff --git a/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc index c81e9c27b6..6e61f10e09 100644 --- a/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc +++ b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc @@ -111,6 +111,15 @@ bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) { } } } + + auto op = n->Op(); + std::vector output_names = + OperationMap::Instance().Get(op->Type()).output_names; + + for (auto& name : output_names) { + if (op->Output(name).size() != 1) return false; + } + return true; } return false; diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py b/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py index d33658c89f..aab789bf63 100644 --- a/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py @@ -109,11 +109,9 @@ class FusionGroupPassTest2(FusionGroupPassTest): tmp_2 = layers.relu(layers.sigmoid(self.feed_vars[3])) tmp_3 = layers.mul(tmp_1, tmp_2) - # TODO(wangchaochaohu): support the case when some vars are set - # stop_gradient = True. - + self.append_gradients(tmp_3) self.num_fused_ops = 2 - self.fetch_list = [tmp_3] + self.fetch_list = [tmp_3, self.grad(tmp_1)] class FusionGroupPassTestFP64(FusionGroupPassTest): -- GitLab