diff --git a/paddle/cinn/frontend/decomposer/activation_test.cc b/paddle/cinn/frontend/decomposer/activation_test.cc index e0bd9a82a48e0e3cf0ab5ee9871ebba22aeb0bcb..78bcbf1ac6cf391cab13711fec580275e7141078 100644 --- a/paddle/cinn/frontend/decomposer/activation_test.cc +++ b/paddle/cinn/frontend/decomposer/activation_test.cc @@ -38,7 +38,7 @@ TEST(Decomposer, relu) { std::vector output_names = {out->id}; std::vector> output_shapes = {{20, 10}}; RunAndCheck( - builder, input_names, output_names, output_shapes, relu_cpu, -1, 1); + &builder, input_names, output_names, output_shapes, relu_cpu, -1, 1); } TEST(Decomposer, relu_grad) { @@ -62,7 +62,7 @@ TEST(Decomposer, relu_grad) { std::vector output_names = {dx->id}; std::vector> output_shapes = {{20, 10}}; RunAndCheck( - builder, input_names, output_names, output_shapes, relu_grad_cpu, -1, 1); + &builder, input_names, output_names, output_shapes, relu_grad_cpu, -1, 1); } TEST(Decomposer, softmax_decomposer) { diff --git a/paddle/cinn/frontend/decomposer/broadcast_test.cc b/paddle/cinn/frontend/decomposer/broadcast_test.cc index 93a58649219b545194dabf1be843dc47343c4667..5f936aab624916c064b5e178ff76e760ad2f6ae5 100644 --- a/paddle/cinn/frontend/decomposer/broadcast_test.cc +++ b/paddle/cinn/frontend/decomposer/broadcast_test.cc @@ -27,7 +27,7 @@ TEST(Decomposer, elementwise_add_bcast0) { std::vector input_names = {x.id().data(), y.id().data()}; std::vector output_names = {out->id}; std::vector> output_shapes = {{4, 10, 20, 10}}; - RunAndCheckShape(builder, input_names, output_names, output_shapes); + RunAndCheckShape(&builder, input_names, output_names, output_shapes); } TEST(Decomposer, elementwise_add_bcase1) { @@ -39,7 +39,7 @@ TEST(Decomposer, elementwise_add_bcase1) { std::vector input_names = {x.id().data(), y.id().data()}; std::vector output_names = {out->id}; std::vector> output_shapes = {{4, 10, 20, 10}}; - RunAndCheckShape(builder, input_names, output_names, output_shapes); + RunAndCheckShape(&builder, input_names, output_names, output_shapes); } TEST(Decomposer, elementwise_add_grad_bcast0) { @@ -52,7 +52,7 @@ TEST(Decomposer, elementwise_add_grad_bcast0) { std::vector input_names = {dout.id().data()}; std::vector output_names = {out_grads[0]->id, out_grads[1]->id}; std::vector> output_shapes = {{4, 1, 20, 10}, {10, 20}}; - RunAndCheckShape(builder, input_names, output_names, output_shapes); + RunAndCheckShape(&builder, input_names, output_names, output_shapes); } TEST(Decomposer, elementwise_add_bcast1) { @@ -80,7 +80,7 @@ TEST(Decomposer, elementwise_add_bcast1) { std::vector output_names = {out->id}; std::vector> output_shapes = {{32, 64, 32, 32}}; RunAndCheck( - builder, input_names, output_names, output_shapes, add_cpu); + &builder, input_names, output_names, output_shapes, add_cpu); } TEST(Decomposer, elementwise_add_bcast1_2) { @@ -108,7 +108,7 @@ TEST(Decomposer, elementwise_add_bcast1_2) { std::vector output_names = {out->id}; std::vector> output_shapes = {{32, 64, 32, 32}}; RunAndCheck( - builder, input_names, output_names, output_shapes, add_cpu); + &builder, input_names, output_names, output_shapes, add_cpu); } TEST(Decomposer, elementwise_add_grad_bcast1) { @@ -140,7 +140,7 @@ TEST(Decomposer, elementwise_add_grad_bcast1) { std::vector output_names = {out_grads[0]->id, out_grads[1]->id}; std::vector> output_shapes = {{32, 64, 32, 32}, {64}}; RunAndCheck( - builder, input_names, output_names, output_shapes, add_grad_cpu); + &builder, input_names, output_names, output_shapes, add_grad_cpu); } TEST(Decomposer, elementwise_add_bcast2) { @@ -165,7 +165,7 @@ TEST(Decomposer, elementwise_add_bcast2) { std::vector output_names = {out->id}; std::vector> output_shapes = {{32, 16}}; RunAndCheck( - builder, input_names, output_names, output_shapes, add_cpu); + &builder, input_names, output_names, output_shapes, add_cpu); } TEST(Decomposer, elementwise_add_bcast2_2) { @@ -190,7 +190,7 @@ TEST(Decomposer, elementwise_add_bcast2_2) { std::vector output_names = {out->id}; std::vector> output_shapes = {{32, 16}}; RunAndCheck( - builder, input_names, output_names, output_shapes, add_cpu); + &builder, input_names, output_names, output_shapes, add_cpu); } TEST(Decomposer, elementwise_add_bcast2_3) { @@ -217,7 +217,7 @@ TEST(Decomposer, elementwise_add_bcast2_3) { std::vector output_names = {out->id}; std::vector> output_shapes = {{32, 16}}; RunAndCheck( - builder, input_names, output_names, output_shapes, add_cpu); + &builder, input_names, output_names, output_shapes, add_cpu); } TEST(Decomposer, elementwise_add_grad_bcast2) { @@ -244,7 +244,7 @@ TEST(Decomposer, elementwise_add_grad_bcast2) { std::vector output_names = {out_grads[0]->id, out_grads[1]->id}; std::vector> output_shapes = {{32, 16}, {1}}; RunAndCheck( - builder, input_names, output_names, output_shapes, add_grad_cpu); + &builder, input_names, output_names, output_shapes, add_grad_cpu); } TEST(Decomposer, elementwise_add_same_dims) { @@ -268,7 +268,7 @@ TEST(Decomposer, elementwise_add_same_dims) { std::vector output_names = {out->id}; std::vector> output_shapes = {{32, 16}}; RunAndCheck( - builder, input_names, output_names, output_shapes, add_cpu); + &builder, input_names, output_names, output_shapes, add_cpu); } TEST(Decomposer, elementwise_add_grad_same_dims) { @@ -295,7 +295,7 @@ TEST(Decomposer, elementwise_add_grad_same_dims) { std::vector output_names = {out_grads[0]->id, out_grads[1]->id}; std::vector> output_shapes = {{32, 16}, {32, 16}}; RunAndCheck( - builder, input_names, output_names, output_shapes, add_grad_cpu); + &builder, input_names, output_names, output_shapes, add_grad_cpu); } } // namespace cinn::frontend diff --git a/paddle/cinn/frontend/decomposer/elementwise_test.cc b/paddle/cinn/frontend/decomposer/elementwise_test.cc index 6f02608ccc378fea14dfdccac672816aa6741f6f..83093840116e1ceefe1653d26640882a29bad7d7 100644 --- a/paddle/cinn/frontend/decomposer/elementwise_test.cc +++ b/paddle/cinn/frontend/decomposer/elementwise_test.cc @@ -42,7 +42,7 @@ TEST(Decomposer, sum) { std::vector output_names = {out->id}; std::vector> output_shapes = {{32, 16}}; RunAndCheck( - builder, input_names, output_names, output_shapes, sum_cpu); + &builder, input_names, output_names, output_shapes, sum_cpu); } } // namespace cinn::frontend diff --git a/paddle/cinn/frontend/decomposer/test_helper.h b/paddle/cinn/frontend/decomposer/test_helper.h index f2d9dddabda8b8d28d82edd059415d85f28cf77a..9188b4ee48a70c73d2c12b17dbaffe1f7d41bd4b 100644 --- a/paddle/cinn/frontend/decomposer/test_helper.h +++ b/paddle/cinn/frontend/decomposer/test_helper.h @@ -193,7 +193,7 @@ void RunDecomposer(Program* prog, const std::vector& fetch_ids = {}); template -void RunAndCheckShape(NetBuilder& builder, +void RunAndCheckShape(NetBuilder* builder, const std::vector& input_names, const std::vector& output_names, const std::vector>& output_shapes, @@ -202,7 +202,7 @@ void RunAndCheckShape(NetBuilder& builder, T low = 0, T high = 1, const std::vector& passes = {"Decomposer"}) { - auto prog = builder.Build(); + auto prog = builder->Build(); Target target = common::DefaultTarget(); RunDecomposer(&prog, target, passes, output_names); auto graph = std::make_shared(prog, target); @@ -238,7 +238,7 @@ void RunAndCheckShape(NetBuilder& builder, } template -void RunAndCheck(NetBuilder& builder, +void RunAndCheck(NetBuilder* builder, const std::vector& input_names, const std::vector& output_names, const std::vector>& output_shapes, diff --git a/paddle/cinn/frontend/pass/fill_constant_rewriter_test.cc b/paddle/cinn/frontend/pass/fill_constant_rewriter_test.cc index 7823e1d63e49383ee2c223b2b924837f99c291df..ee7b428b54414c899305f22a52aad0036ff4a82a 100644 --- a/paddle/cinn/frontend/pass/fill_constant_rewriter_test.cc +++ b/paddle/cinn/frontend/pass/fill_constant_rewriter_test.cc @@ -42,7 +42,7 @@ TEST(FillConstantRewriter, remove_reshape_single) { std::vector program_passes = {"FillConstantRewriter", "RemoveIdentity"}; int num_removed_ops = - tester.RunAndCheck(builder, program_passes, input_names, output_names); + tester.RunAndCheck(&builder, program_passes, input_names, output_names); ASSERT_EQ(num_removed_ops, 2); } @@ -68,7 +68,7 @@ TEST(FillConstantRewriter, remove_reshape_with_fill_constant) { std::vector program_passes = {"FillConstantRewriter", "RemoveIdentity"}; int num_removed_ops = - tester.RunAndCheck(builder, program_passes, input_names, output_names); + tester.RunAndCheck(&builder, program_passes, input_names, output_names); ASSERT_EQ(num_removed_ops, 2); } @@ -93,7 +93,7 @@ TEST(FillConstantRewriter, remove_scale_single) { std::vector program_passes = {"FillConstantRewriter", "RemoveIdentity"}; int num_removed_ops = - tester.RunAndCheck(builder, program_passes, input_names, output_names); + tester.RunAndCheck(&builder, program_passes, input_names, output_names); ASSERT_EQ(num_removed_ops, 2); } @@ -118,7 +118,7 @@ TEST(FillConstantRewriter, remove_scale_with_fill_constant) { std::vector program_passes = {"FillConstantRewriter", "RemoveIdentity"}; int num_removed_ops = - tester.RunAndCheck(builder, program_passes, input_names, output_names); + tester.RunAndCheck(&builder, program_passes, input_names, output_names); ASSERT_EQ(num_removed_ops, 2); } @@ -150,7 +150,7 @@ TEST(FillConstantRewriter, remove_multi_scale_with_fill_constant) { std::vector program_passes = {"FillConstantRewriter", "RemoveIdentity"}; int num_removed_ops = - tester.RunAndCheck(builder, program_passes, input_names, output_names); + tester.RunAndCheck(&builder, program_passes, input_names, output_names); ASSERT_EQ(num_removed_ops, 4); } @@ -167,7 +167,7 @@ TEST(FillConstantRewriter, two_fill_constant) { std::vector program_passes = {"FillConstantRewriter", "RemoveIdentity"}; int num_removed_ops = - tester.RunAndCheck(builder, program_passes, input_names, output_names); + tester.RunAndCheck(&builder, program_passes, input_names, output_names); ASSERT_EQ(num_removed_ops, 0); } diff --git a/paddle/cinn/frontend/pass/remove_identity_test.cc b/paddle/cinn/frontend/pass/remove_identity_test.cc index 13ad1e1a700198954726c2de134823851a8c5d25..f67bb27b388accfb3052fa267e725934602f8e5f 100644 --- a/paddle/cinn/frontend/pass/remove_identity_test.cc +++ b/paddle/cinn/frontend/pass/remove_identity_test.cc @@ -40,7 +40,7 @@ TEST(RemoveIdentity, remove_single) { std::vector program_passes = {"RemoveIdentity", "DeadCodeEliminate"}; int num_removed_ops = - tester.RunAndCheck(builder, program_passes, input_names, output_names); + tester.RunAndCheck(&builder, program_passes, input_names, output_names); ASSERT_EQ(num_removed_ops, 3); } @@ -63,7 +63,7 @@ TEST(RemoveIdentity, remove_branch) { std::vector output_names = {reduce_sum_1->id, reduce_sum_2->id}; std::vector program_passes = {"RemoveIdentity"}; int num_removed_ops = - tester.RunAndCheck(builder, program_passes, input_names, output_names); + tester.RunAndCheck(&builder, program_passes, input_names, output_names); ASSERT_EQ(num_removed_ops, 1); } @@ -92,7 +92,7 @@ TEST(RemoveIdentity, remove_multiple) { std::vector output_names = {mul_1->id}; std::vector program_passes = {"RemoveIdentity"}; int num_removed_ops = - tester.RunAndCheck(builder, program_passes, input_names, output_names); + tester.RunAndCheck(&builder, program_passes, input_names, output_names); ASSERT_EQ(num_removed_ops, 3); } @@ -121,7 +121,7 @@ TEST(RemoveIdentity, cannot_remove_fetch) { std::vector output_names = {identity_2->id, mul_1->id}; std::vector program_passes = {"RemoveIdentity"}; int num_removed_ops = - tester.RunAndCheck(builder, program_passes, input_names, output_names); + tester.RunAndCheck(&builder, program_passes, input_names, output_names); ASSERT_EQ(num_removed_ops, 1); } diff --git a/paddle/cinn/frontend/pass/test_helper.h b/paddle/cinn/frontend/pass/test_helper.h index ba5e0058e8e8844bd0c81fb609ce1e3d8c36e94c..468f15a164d9c12fa19bdf7425ff2eb2b581757b 100644 --- a/paddle/cinn/frontend/pass/test_helper.h +++ b/paddle/cinn/frontend/pass/test_helper.h @@ -75,11 +75,11 @@ class PassTest { public: PassTest() { target_ = common::DefaultTarget(); } - int RunAndCheck(NetBuilder& builder, + int RunAndCheck(NetBuilder* builder, const std::vector& program_passes, const std::vector& input_names, const std::vector& output_names) { - auto program = builder.Build(); + auto program = builder->Build(); CHECK(IsValid(program)) << "The origin program is not valid."; int origin_program_size = program.size(); LOG(INFO) << "Run origin program"; diff --git a/paddle/cinn/hlir/pass/fusion_merge_pass.cc b/paddle/cinn/hlir/pass/fusion_merge_pass.cc index dc09bd5c7b5723c14c86d9b00b0e6d79777e6737..fc0b372b3ede0fb59971745186888f76e3558f99 100644 --- a/paddle/cinn/hlir/pass/fusion_merge_pass.cc +++ b/paddle/cinn/hlir/pass/fusion_merge_pass.cc @@ -176,7 +176,7 @@ class FusionMergePassHelper : public FusionHelperBase { bool HorizontalFusion( GroupPtr producer, - std::unordered_set& consumers) { + const std::unordered_set& consumers) { VLOG(3) << "HorizontalFusion...!"; if (consumers.size() <= 1) { return false; @@ -249,7 +249,7 @@ class FusionMergePassHelper : public FusionHelperBase { return updated; } - void HorizontalFuse(GroupList& consumers) { + void HorizontalFuse(const GroupList& consumers) { VLOG(3) << "HorizontalFuse Groups..."; // create fusion group auto fused_group = std::make_shared(); @@ -400,8 +400,8 @@ class FusionMergePassHelper : public FusionHelperBase { } bool VerticalFusion( - GroupPtr& producer, - std::unordered_set& consumers, + const GroupPtr& producer, + const std::unordered_set& consumers, bool recompute) { VLOG(3) << "VerticalFusion, Number of Consumers : " << consumers.size(); auto& relation = fusion_relation_map_[producer->op_pattern_kind]; @@ -463,14 +463,14 @@ class FusionMergePassHelper : public FusionHelperBase { if (!recompute) { return false; } else { - RecomputeEleGraph(producer, fuse_consumers_unsafe); + RecomputeEleGraph(producer, &fuse_consumers_unsafe); VerticalFuse(producer, fuse_consumers_unsafe); return true; } } if (fuse_consumers.size()) { - SelectConsumerToFuse(producer, fuse_consumers); + SelectConsumerToFuse(producer, &fuse_consumers); } // if fusionable consumers exist @@ -482,9 +482,9 @@ class FusionMergePassHelper : public FusionHelperBase { return false; } - void VerticalFuse( - GroupPtr& producer, - std::unordered_set& fusionable_consumers) { + void VerticalFuse(const GroupPtr& producer, + const std::unordered_set& + fusionable_consumers) { VLOG(3) << "VerticalFuse...!"; GroupList fused_groups; GroupPtr master_fuesd_group(nullptr); @@ -671,7 +671,7 @@ class FusionMergePassHelper : public FusionHelperBase { void RecomputeEleGraph( const GroupPtr& producer, - std::unordered_set& fusionable_consumers) { + std::unordered_set* fusionable_consumers) { if (producer->op_pattern_kind != framework::kElementWise) { SelectConsumerToFuse(producer, fusionable_consumers); } @@ -679,11 +679,11 @@ class FusionMergePassHelper : public FusionHelperBase { void SelectConsumerToFuse( const GroupPtr& producer, - std::unordered_set& fusionable_consumers) { + std::unordered_set* fusionable_consumers) { // if is const op if (is_const_group(this, producer)) { std::unordered_set candidates; - for (auto& consumer : fusionable_consumers) { + for (auto& consumer : *fusionable_consumers) { // if can be output node. if (is_same_shape(this, producer, consumer)) { candidates.insert(consumer); @@ -707,10 +707,10 @@ class FusionMergePassHelper : public FusionHelperBase { CHECK_GE(producer->consumer_groups.size(), candidates.size()); if (producer->consumer_groups.size() == 0 && candidates.size() == 0 && output_nodes_set_.count(producer->CollectNodes()[0]) == 0) { - producer->belong_groups.insert(*fusionable_consumers.begin()); + producer->belong_groups.insert(*fusionable_consumers->begin()); } - fusionable_consumers = candidates; + *fusionable_consumers = candidates; return; } // 1 to 1 fusion. @@ -720,7 +720,7 @@ class FusionMergePassHelper : public FusionHelperBase { if (FLAGS_enhance_vertical_fusion_with_recompute) { std::vector candidates; - for (auto& consumer : fusionable_consumers) { + for (auto& consumer : *fusionable_consumers) { if (consumer->op_pattern_kind == framework::kElementWise) { candidates.push_back(consumer); continue; @@ -764,13 +764,13 @@ class FusionMergePassHelper : public FusionHelperBase { return lhs->op_pattern_kind < rhs->op_pattern_kind; }); - fusionable_consumers.clear(); + fusionable_consumers->clear(); if (candidates.size()) { - fusionable_consumers.insert(*candidates.begin()); + fusionable_consumers->insert(*candidates.begin()); } } else { std::unordered_set candidates; - for (auto& consumer : fusionable_consumers) { + for (auto& consumer : *fusionable_consumers) { if (consumer->op_pattern_kind == framework::kElementWise) { candidates.insert(consumer); continue; @@ -787,9 +787,9 @@ class FusionMergePassHelper : public FusionHelperBase { } } - fusionable_consumers.clear(); + fusionable_consumers->clear(); if (candidates.size()) { - fusionable_consumers.insert(*candidates.begin()); + fusionable_consumers->insert(*candidates.begin()); } } }