diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator.cc b/paddle/fluid/framework/ir/fusion_group/code_generator.cc index fd224de483b2e066fe13851eb03eba54f4b21a89..e63764ae91528e2b382b7bcfb8b71c8ac956523a 100644 --- a/paddle/fluid/framework/ir/fusion_group/code_generator.cc +++ b/paddle/fluid/framework/ir/fusion_group/code_generator.cc @@ -99,6 +99,7 @@ std::vector CodeGenerator::ConvertToExpressions( input_ids.push_back(-1); } } + // Output ids should be set in fixed order, like: // - dx, dy in backward operations std::vector output_ids; @@ -106,10 +107,6 @@ std::vector CodeGenerator::ConvertToExpressions( OperationMap::Instance().Get(op->Type()).output_names; for (auto& name : output_names) { - PADDLE_ENFORCE_EQ( - op->Output(name).size(), 1U, - platform::errors::InvalidArgument( - "Output(%s) of operation %s is not set.", name, op->Type())); PADDLE_ENFORCE_NE( var_ids.find(op->Output(name)[0]), var_ids.end(), platform::errors::InvalidArgument( diff --git a/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc index c81e9c27b6b06a9849b98d292fac14285a31191a..6e61f10e091c6519bd68dab70cfaaeb2ac64734f 100644 --- a/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc +++ b/paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc @@ -111,6 +111,15 @@ bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) { } } } + + auto op = n->Op(); + std::vector output_names = + OperationMap::Instance().Get(op->Type()).output_names; + + for (auto& name : output_names) { + if (op->Output(name).size() != 1) return false; + } + return true; } return false; diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py b/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py index d33658c89f1038d4deb9d289f25400793bd96fd8..aab789bf6399cdb435e6bfa896b9b23aed3dfe3c 100644 --- a/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py @@ -109,11 +109,9 @@ class FusionGroupPassTest2(FusionGroupPassTest): tmp_2 = layers.relu(layers.sigmoid(self.feed_vars[3])) tmp_3 = layers.mul(tmp_1, tmp_2) - # TODO(wangchaochaohu): support the case when some vars are set - # stop_gradient = True. - + self.append_gradients(tmp_3) self.num_fused_ops = 2 - self.fetch_list = [tmp_3] + self.fetch_list = [tmp_3, self.grad(tmp_1)] class FusionGroupPassTestFP64(FusionGroupPassTest):