未验证 提交 e2009545 编写于 作者: Z zyfncg 提交者: GitHub

[CINN] Adjust the code format in cinn (#55009)

* adjust the code format in cinn

* fix merge conflict
上级 e5725680
...@@ -38,7 +38,7 @@ TEST(Decomposer, relu) { ...@@ -38,7 +38,7 @@ TEST(Decomposer, relu) {
std::vector<std::string> output_names = {out->id}; std::vector<std::string> output_names = {out->id};
std::vector<std::vector<int>> output_shapes = {{20, 10}}; std::vector<std::vector<int>> output_shapes = {{20, 10}};
RunAndCheck<float>( RunAndCheck<float>(
builder, input_names, output_names, output_shapes, relu_cpu, -1, 1); &builder, input_names, output_names, output_shapes, relu_cpu, -1, 1);
} }
TEST(Decomposer, relu_grad) { TEST(Decomposer, relu_grad) {
...@@ -62,7 +62,7 @@ TEST(Decomposer, relu_grad) { ...@@ -62,7 +62,7 @@ TEST(Decomposer, relu_grad) {
std::vector<std::string> output_names = {dx->id}; std::vector<std::string> output_names = {dx->id};
std::vector<std::vector<int>> output_shapes = {{20, 10}}; std::vector<std::vector<int>> output_shapes = {{20, 10}};
RunAndCheck<float>( RunAndCheck<float>(
builder, input_names, output_names, output_shapes, relu_grad_cpu, -1, 1); &builder, input_names, output_names, output_shapes, relu_grad_cpu, -1, 1);
} }
TEST(Decomposer, softmax_decomposer) { TEST(Decomposer, softmax_decomposer) {
......
...@@ -27,7 +27,7 @@ TEST(Decomposer, elementwise_add_bcast0) { ...@@ -27,7 +27,7 @@ TEST(Decomposer, elementwise_add_bcast0) {
std::vector<std::string> input_names = {x.id().data(), y.id().data()}; std::vector<std::string> input_names = {x.id().data(), y.id().data()};
std::vector<std::string> output_names = {out->id}; std::vector<std::string> output_names = {out->id};
std::vector<std::vector<int>> output_shapes = {{4, 10, 20, 10}}; std::vector<std::vector<int>> output_shapes = {{4, 10, 20, 10}};
RunAndCheckShape<float>(builder, input_names, output_names, output_shapes); RunAndCheckShape<float>(&builder, input_names, output_names, output_shapes);
} }
TEST(Decomposer, elementwise_add_bcase1) { TEST(Decomposer, elementwise_add_bcase1) {
...@@ -39,7 +39,7 @@ TEST(Decomposer, elementwise_add_bcase1) { ...@@ -39,7 +39,7 @@ TEST(Decomposer, elementwise_add_bcase1) {
std::vector<std::string> input_names = {x.id().data(), y.id().data()}; std::vector<std::string> input_names = {x.id().data(), y.id().data()};
std::vector<std::string> output_names = {out->id}; std::vector<std::string> output_names = {out->id};
std::vector<std::vector<int>> output_shapes = {{4, 10, 20, 10}}; std::vector<std::vector<int>> output_shapes = {{4, 10, 20, 10}};
RunAndCheckShape<float>(builder, input_names, output_names, output_shapes); RunAndCheckShape<float>(&builder, input_names, output_names, output_shapes);
} }
TEST(Decomposer, elementwise_add_grad_bcast0) { TEST(Decomposer, elementwise_add_grad_bcast0) {
...@@ -52,7 +52,7 @@ TEST(Decomposer, elementwise_add_grad_bcast0) { ...@@ -52,7 +52,7 @@ TEST(Decomposer, elementwise_add_grad_bcast0) {
std::vector<std::string> input_names = {dout.id().data()}; std::vector<std::string> input_names = {dout.id().data()};
std::vector<std::string> output_names = {out_grads[0]->id, out_grads[1]->id}; std::vector<std::string> output_names = {out_grads[0]->id, out_grads[1]->id};
std::vector<std::vector<int>> output_shapes = {{4, 1, 20, 10}, {10, 20}}; std::vector<std::vector<int>> output_shapes = {{4, 1, 20, 10}, {10, 20}};
RunAndCheckShape<float>(builder, input_names, output_names, output_shapes); RunAndCheckShape<float>(&builder, input_names, output_names, output_shapes);
} }
TEST(Decomposer, elementwise_add_bcast1) { TEST(Decomposer, elementwise_add_bcast1) {
...@@ -80,7 +80,7 @@ TEST(Decomposer, elementwise_add_bcast1) { ...@@ -80,7 +80,7 @@ TEST(Decomposer, elementwise_add_bcast1) {
std::vector<std::string> output_names = {out->id}; std::vector<std::string> output_names = {out->id};
std::vector<std::vector<int>> output_shapes = {{32, 64, 32, 32}}; std::vector<std::vector<int>> output_shapes = {{32, 64, 32, 32}};
RunAndCheck<float>( RunAndCheck<float>(
builder, input_names, output_names, output_shapes, add_cpu); &builder, input_names, output_names, output_shapes, add_cpu);
} }
TEST(Decomposer, elementwise_add_bcast1_2) { TEST(Decomposer, elementwise_add_bcast1_2) {
...@@ -108,7 +108,7 @@ TEST(Decomposer, elementwise_add_bcast1_2) { ...@@ -108,7 +108,7 @@ TEST(Decomposer, elementwise_add_bcast1_2) {
std::vector<std::string> output_names = {out->id}; std::vector<std::string> output_names = {out->id};
std::vector<std::vector<int>> output_shapes = {{32, 64, 32, 32}}; std::vector<std::vector<int>> output_shapes = {{32, 64, 32, 32}};
RunAndCheck<float>( RunAndCheck<float>(
builder, input_names, output_names, output_shapes, add_cpu); &builder, input_names, output_names, output_shapes, add_cpu);
} }
TEST(Decomposer, elementwise_add_grad_bcast1) { TEST(Decomposer, elementwise_add_grad_bcast1) {
...@@ -140,7 +140,7 @@ TEST(Decomposer, elementwise_add_grad_bcast1) { ...@@ -140,7 +140,7 @@ TEST(Decomposer, elementwise_add_grad_bcast1) {
std::vector<std::string> output_names = {out_grads[0]->id, out_grads[1]->id}; std::vector<std::string> output_names = {out_grads[0]->id, out_grads[1]->id};
std::vector<std::vector<int>> output_shapes = {{32, 64, 32, 32}, {64}}; std::vector<std::vector<int>> output_shapes = {{32, 64, 32, 32}, {64}};
RunAndCheck<float>( RunAndCheck<float>(
builder, input_names, output_names, output_shapes, add_grad_cpu); &builder, input_names, output_names, output_shapes, add_grad_cpu);
} }
TEST(Decomposer, elementwise_add_bcast2) { TEST(Decomposer, elementwise_add_bcast2) {
...@@ -165,7 +165,7 @@ TEST(Decomposer, elementwise_add_bcast2) { ...@@ -165,7 +165,7 @@ TEST(Decomposer, elementwise_add_bcast2) {
std::vector<std::string> output_names = {out->id}; std::vector<std::string> output_names = {out->id};
std::vector<std::vector<int>> output_shapes = {{32, 16}}; std::vector<std::vector<int>> output_shapes = {{32, 16}};
RunAndCheck<float>( RunAndCheck<float>(
builder, input_names, output_names, output_shapes, add_cpu); &builder, input_names, output_names, output_shapes, add_cpu);
} }
TEST(Decomposer, elementwise_add_bcast2_2) { TEST(Decomposer, elementwise_add_bcast2_2) {
...@@ -190,7 +190,7 @@ TEST(Decomposer, elementwise_add_bcast2_2) { ...@@ -190,7 +190,7 @@ TEST(Decomposer, elementwise_add_bcast2_2) {
std::vector<std::string> output_names = {out->id}; std::vector<std::string> output_names = {out->id};
std::vector<std::vector<int>> output_shapes = {{32, 16}}; std::vector<std::vector<int>> output_shapes = {{32, 16}};
RunAndCheck<float>( RunAndCheck<float>(
builder, input_names, output_names, output_shapes, add_cpu); &builder, input_names, output_names, output_shapes, add_cpu);
} }
TEST(Decomposer, elementwise_add_bcast2_3) { TEST(Decomposer, elementwise_add_bcast2_3) {
...@@ -217,7 +217,7 @@ TEST(Decomposer, elementwise_add_bcast2_3) { ...@@ -217,7 +217,7 @@ TEST(Decomposer, elementwise_add_bcast2_3) {
std::vector<std::string> output_names = {out->id}; std::vector<std::string> output_names = {out->id};
std::vector<std::vector<int>> output_shapes = {{32, 16}}; std::vector<std::vector<int>> output_shapes = {{32, 16}};
RunAndCheck<int_ty>( RunAndCheck<int_ty>(
builder, input_names, output_names, output_shapes, add_cpu); &builder, input_names, output_names, output_shapes, add_cpu);
} }
TEST(Decomposer, elementwise_add_grad_bcast2) { TEST(Decomposer, elementwise_add_grad_bcast2) {
...@@ -244,7 +244,7 @@ TEST(Decomposer, elementwise_add_grad_bcast2) { ...@@ -244,7 +244,7 @@ TEST(Decomposer, elementwise_add_grad_bcast2) {
std::vector<std::string> output_names = {out_grads[0]->id, out_grads[1]->id}; std::vector<std::string> output_names = {out_grads[0]->id, out_grads[1]->id};
std::vector<std::vector<int>> output_shapes = {{32, 16}, {1}}; std::vector<std::vector<int>> output_shapes = {{32, 16}, {1}};
RunAndCheck<float>( RunAndCheck<float>(
builder, input_names, output_names, output_shapes, add_grad_cpu); &builder, input_names, output_names, output_shapes, add_grad_cpu);
} }
TEST(Decomposer, elementwise_add_same_dims) { TEST(Decomposer, elementwise_add_same_dims) {
...@@ -268,7 +268,7 @@ TEST(Decomposer, elementwise_add_same_dims) { ...@@ -268,7 +268,7 @@ TEST(Decomposer, elementwise_add_same_dims) {
std::vector<std::string> output_names = {out->id}; std::vector<std::string> output_names = {out->id};
std::vector<std::vector<int>> output_shapes = {{32, 16}}; std::vector<std::vector<int>> output_shapes = {{32, 16}};
RunAndCheck<float>( RunAndCheck<float>(
builder, input_names, output_names, output_shapes, add_cpu); &builder, input_names, output_names, output_shapes, add_cpu);
} }
TEST(Decomposer, elementwise_add_grad_same_dims) { TEST(Decomposer, elementwise_add_grad_same_dims) {
...@@ -295,7 +295,7 @@ TEST(Decomposer, elementwise_add_grad_same_dims) { ...@@ -295,7 +295,7 @@ TEST(Decomposer, elementwise_add_grad_same_dims) {
std::vector<std::string> output_names = {out_grads[0]->id, out_grads[1]->id}; std::vector<std::string> output_names = {out_grads[0]->id, out_grads[1]->id};
std::vector<std::vector<int>> output_shapes = {{32, 16}, {32, 16}}; std::vector<std::vector<int>> output_shapes = {{32, 16}, {32, 16}};
RunAndCheck<float>( RunAndCheck<float>(
builder, input_names, output_names, output_shapes, add_grad_cpu); &builder, input_names, output_names, output_shapes, add_grad_cpu);
} }
} // namespace cinn::frontend } // namespace cinn::frontend
...@@ -42,7 +42,7 @@ TEST(Decomposer, sum) { ...@@ -42,7 +42,7 @@ TEST(Decomposer, sum) {
std::vector<std::string> output_names = {out->id}; std::vector<std::string> output_names = {out->id};
std::vector<std::vector<int>> output_shapes = {{32, 16}}; std::vector<std::vector<int>> output_shapes = {{32, 16}};
RunAndCheck<float>( RunAndCheck<float>(
builder, input_names, output_names, output_shapes, sum_cpu); &builder, input_names, output_names, output_shapes, sum_cpu);
} }
} // namespace cinn::frontend } // namespace cinn::frontend
...@@ -193,7 +193,7 @@ void RunDecomposer(Program* prog, ...@@ -193,7 +193,7 @@ void RunDecomposer(Program* prog,
const std::vector<std::string>& fetch_ids = {}); const std::vector<std::string>& fetch_ids = {});
template <typename T> template <typename T>
void RunAndCheckShape(NetBuilder& builder, void RunAndCheckShape(NetBuilder* builder,
const std::vector<std::string>& input_names, const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names, const std::vector<std::string>& output_names,
const std::vector<std::vector<int>>& output_shapes, const std::vector<std::vector<int>>& output_shapes,
...@@ -202,7 +202,7 @@ void RunAndCheckShape(NetBuilder& builder, ...@@ -202,7 +202,7 @@ void RunAndCheckShape(NetBuilder& builder,
T low = 0, T low = 0,
T high = 1, T high = 1,
const std::vector<std::string>& passes = {"Decomposer"}) { const std::vector<std::string>& passes = {"Decomposer"}) {
auto prog = builder.Build(); auto prog = builder->Build();
Target target = common::DefaultTarget(); Target target = common::DefaultTarget();
RunDecomposer(&prog, target, passes, output_names); RunDecomposer(&prog, target, passes, output_names);
auto graph = std::make_shared<hlir::framework::Graph>(prog, target); auto graph = std::make_shared<hlir::framework::Graph>(prog, target);
...@@ -238,7 +238,7 @@ void RunAndCheckShape(NetBuilder& builder, ...@@ -238,7 +238,7 @@ void RunAndCheckShape(NetBuilder& builder,
} }
template <typename T> template <typename T>
void RunAndCheck(NetBuilder& builder, void RunAndCheck(NetBuilder* builder,
const std::vector<std::string>& input_names, const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names, const std::vector<std::string>& output_names,
const std::vector<std::vector<int>>& output_shapes, const std::vector<std::vector<int>>& output_shapes,
......
...@@ -42,7 +42,7 @@ TEST(FillConstantRewriter, remove_reshape_single) { ...@@ -42,7 +42,7 @@ TEST(FillConstantRewriter, remove_reshape_single) {
std::vector<std::string> program_passes = {"FillConstantRewriter", std::vector<std::string> program_passes = {"FillConstantRewriter",
"RemoveIdentity"}; "RemoveIdentity"};
int num_removed_ops = int num_removed_ops =
tester.RunAndCheck(builder, program_passes, input_names, output_names); tester.RunAndCheck(&builder, program_passes, input_names, output_names);
ASSERT_EQ(num_removed_ops, 2); ASSERT_EQ(num_removed_ops, 2);
} }
...@@ -68,7 +68,7 @@ TEST(FillConstantRewriter, remove_reshape_with_fill_constant) { ...@@ -68,7 +68,7 @@ TEST(FillConstantRewriter, remove_reshape_with_fill_constant) {
std::vector<std::string> program_passes = {"FillConstantRewriter", std::vector<std::string> program_passes = {"FillConstantRewriter",
"RemoveIdentity"}; "RemoveIdentity"};
int num_removed_ops = int num_removed_ops =
tester.RunAndCheck(builder, program_passes, input_names, output_names); tester.RunAndCheck(&builder, program_passes, input_names, output_names);
ASSERT_EQ(num_removed_ops, 2); ASSERT_EQ(num_removed_ops, 2);
} }
...@@ -93,7 +93,7 @@ TEST(FillConstantRewriter, remove_scale_single) { ...@@ -93,7 +93,7 @@ TEST(FillConstantRewriter, remove_scale_single) {
std::vector<std::string> program_passes = {"FillConstantRewriter", std::vector<std::string> program_passes = {"FillConstantRewriter",
"RemoveIdentity"}; "RemoveIdentity"};
int num_removed_ops = int num_removed_ops =
tester.RunAndCheck(builder, program_passes, input_names, output_names); tester.RunAndCheck(&builder, program_passes, input_names, output_names);
ASSERT_EQ(num_removed_ops, 2); ASSERT_EQ(num_removed_ops, 2);
} }
...@@ -118,7 +118,7 @@ TEST(FillConstantRewriter, remove_scale_with_fill_constant) { ...@@ -118,7 +118,7 @@ TEST(FillConstantRewriter, remove_scale_with_fill_constant) {
std::vector<std::string> program_passes = {"FillConstantRewriter", std::vector<std::string> program_passes = {"FillConstantRewriter",
"RemoveIdentity"}; "RemoveIdentity"};
int num_removed_ops = int num_removed_ops =
tester.RunAndCheck(builder, program_passes, input_names, output_names); tester.RunAndCheck(&builder, program_passes, input_names, output_names);
ASSERT_EQ(num_removed_ops, 2); ASSERT_EQ(num_removed_ops, 2);
} }
...@@ -150,7 +150,7 @@ TEST(FillConstantRewriter, remove_multi_scale_with_fill_constant) { ...@@ -150,7 +150,7 @@ TEST(FillConstantRewriter, remove_multi_scale_with_fill_constant) {
std::vector<std::string> program_passes = {"FillConstantRewriter", std::vector<std::string> program_passes = {"FillConstantRewriter",
"RemoveIdentity"}; "RemoveIdentity"};
int num_removed_ops = int num_removed_ops =
tester.RunAndCheck(builder, program_passes, input_names, output_names); tester.RunAndCheck(&builder, program_passes, input_names, output_names);
ASSERT_EQ(num_removed_ops, 4); ASSERT_EQ(num_removed_ops, 4);
} }
...@@ -167,7 +167,7 @@ TEST(FillConstantRewriter, two_fill_constant) { ...@@ -167,7 +167,7 @@ TEST(FillConstantRewriter, two_fill_constant) {
std::vector<std::string> program_passes = {"FillConstantRewriter", std::vector<std::string> program_passes = {"FillConstantRewriter",
"RemoveIdentity"}; "RemoveIdentity"};
int num_removed_ops = int num_removed_ops =
tester.RunAndCheck(builder, program_passes, input_names, output_names); tester.RunAndCheck(&builder, program_passes, input_names, output_names);
ASSERT_EQ(num_removed_ops, 0); ASSERT_EQ(num_removed_ops, 0);
} }
......
...@@ -40,7 +40,7 @@ TEST(RemoveIdentity, remove_single) { ...@@ -40,7 +40,7 @@ TEST(RemoveIdentity, remove_single) {
std::vector<std::string> program_passes = {"RemoveIdentity", std::vector<std::string> program_passes = {"RemoveIdentity",
"DeadCodeEliminate"}; "DeadCodeEliminate"};
int num_removed_ops = int num_removed_ops =
tester.RunAndCheck(builder, program_passes, input_names, output_names); tester.RunAndCheck(&builder, program_passes, input_names, output_names);
ASSERT_EQ(num_removed_ops, 3); ASSERT_EQ(num_removed_ops, 3);
} }
...@@ -63,7 +63,7 @@ TEST(RemoveIdentity, remove_branch) { ...@@ -63,7 +63,7 @@ TEST(RemoveIdentity, remove_branch) {
std::vector<std::string> output_names = {reduce_sum_1->id, reduce_sum_2->id}; std::vector<std::string> output_names = {reduce_sum_1->id, reduce_sum_2->id};
std::vector<std::string> program_passes = {"RemoveIdentity"}; std::vector<std::string> program_passes = {"RemoveIdentity"};
int num_removed_ops = int num_removed_ops =
tester.RunAndCheck(builder, program_passes, input_names, output_names); tester.RunAndCheck(&builder, program_passes, input_names, output_names);
ASSERT_EQ(num_removed_ops, 1); ASSERT_EQ(num_removed_ops, 1);
} }
...@@ -92,7 +92,7 @@ TEST(RemoveIdentity, remove_multiple) { ...@@ -92,7 +92,7 @@ TEST(RemoveIdentity, remove_multiple) {
std::vector<std::string> output_names = {mul_1->id}; std::vector<std::string> output_names = {mul_1->id};
std::vector<std::string> program_passes = {"RemoveIdentity"}; std::vector<std::string> program_passes = {"RemoveIdentity"};
int num_removed_ops = int num_removed_ops =
tester.RunAndCheck(builder, program_passes, input_names, output_names); tester.RunAndCheck(&builder, program_passes, input_names, output_names);
ASSERT_EQ(num_removed_ops, 3); ASSERT_EQ(num_removed_ops, 3);
} }
...@@ -121,7 +121,7 @@ TEST(RemoveIdentity, cannot_remove_fetch) { ...@@ -121,7 +121,7 @@ TEST(RemoveIdentity, cannot_remove_fetch) {
std::vector<std::string> output_names = {identity_2->id, mul_1->id}; std::vector<std::string> output_names = {identity_2->id, mul_1->id};
std::vector<std::string> program_passes = {"RemoveIdentity"}; std::vector<std::string> program_passes = {"RemoveIdentity"};
int num_removed_ops = int num_removed_ops =
tester.RunAndCheck(builder, program_passes, input_names, output_names); tester.RunAndCheck(&builder, program_passes, input_names, output_names);
ASSERT_EQ(num_removed_ops, 1); ASSERT_EQ(num_removed_ops, 1);
} }
......
...@@ -75,11 +75,11 @@ class PassTest { ...@@ -75,11 +75,11 @@ class PassTest {
public: public:
PassTest() { target_ = common::DefaultTarget(); } PassTest() { target_ = common::DefaultTarget(); }
int RunAndCheck(NetBuilder& builder, int RunAndCheck(NetBuilder* builder,
const std::vector<std::string>& program_passes, const std::vector<std::string>& program_passes,
const std::vector<std::string>& input_names, const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names) { const std::vector<std::string>& output_names) {
auto program = builder.Build(); auto program = builder->Build();
CHECK(IsValid(program)) << "The origin program is not valid."; CHECK(IsValid(program)) << "The origin program is not valid.";
int origin_program_size = program.size(); int origin_program_size = program.size();
LOG(INFO) << "Run origin program"; LOG(INFO) << "Run origin program";
......
...@@ -176,7 +176,7 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -176,7 +176,7 @@ class FusionMergePassHelper : public FusionHelperBase {
bool HorizontalFusion( bool HorizontalFusion(
GroupPtr producer, GroupPtr producer,
std::unordered_set<GroupPtr, Hasher, Comparator>& consumers) { const std::unordered_set<GroupPtr, Hasher, Comparator>& consumers) {
VLOG(3) << "HorizontalFusion...!"; VLOG(3) << "HorizontalFusion...!";
if (consumers.size() <= 1) { if (consumers.size() <= 1) {
return false; return false;
...@@ -249,7 +249,7 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -249,7 +249,7 @@ class FusionMergePassHelper : public FusionHelperBase {
return updated; return updated;
} }
void HorizontalFuse(GroupList& consumers) { void HorizontalFuse(const GroupList& consumers) {
VLOG(3) << "HorizontalFuse Groups..."; VLOG(3) << "HorizontalFuse Groups...";
// create fusion group // create fusion group
auto fused_group = std::make_shared<Graph::Group>(); auto fused_group = std::make_shared<Graph::Group>();
...@@ -400,8 +400,8 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -400,8 +400,8 @@ class FusionMergePassHelper : public FusionHelperBase {
} }
bool VerticalFusion( bool VerticalFusion(
GroupPtr& producer, const GroupPtr& producer,
std::unordered_set<GroupPtr, Hasher, Comparator>& consumers, const std::unordered_set<GroupPtr, Hasher, Comparator>& consumers,
bool recompute) { bool recompute) {
VLOG(3) << "VerticalFusion, Number of Consumers : " << consumers.size(); VLOG(3) << "VerticalFusion, Number of Consumers : " << consumers.size();
auto& relation = fusion_relation_map_[producer->op_pattern_kind]; auto& relation = fusion_relation_map_[producer->op_pattern_kind];
...@@ -463,14 +463,14 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -463,14 +463,14 @@ class FusionMergePassHelper : public FusionHelperBase {
if (!recompute) { if (!recompute) {
return false; return false;
} else { } else {
RecomputeEleGraph(producer, fuse_consumers_unsafe); RecomputeEleGraph(producer, &fuse_consumers_unsafe);
VerticalFuse(producer, fuse_consumers_unsafe); VerticalFuse(producer, fuse_consumers_unsafe);
return true; return true;
} }
} }
if (fuse_consumers.size()) { if (fuse_consumers.size()) {
SelectConsumerToFuse(producer, fuse_consumers); SelectConsumerToFuse(producer, &fuse_consumers);
} }
// if fusionable consumers exist // if fusionable consumers exist
...@@ -482,9 +482,9 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -482,9 +482,9 @@ class FusionMergePassHelper : public FusionHelperBase {
return false; return false;
} }
void VerticalFuse( void VerticalFuse(const GroupPtr& producer,
GroupPtr& producer, const std::unordered_set<GroupPtr, Hasher, Comparator>&
std::unordered_set<GroupPtr, Hasher, Comparator>& fusionable_consumers) { fusionable_consumers) {
VLOG(3) << "VerticalFuse...!"; VLOG(3) << "VerticalFuse...!";
GroupList fused_groups; GroupList fused_groups;
GroupPtr master_fuesd_group(nullptr); GroupPtr master_fuesd_group(nullptr);
...@@ -671,7 +671,7 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -671,7 +671,7 @@ class FusionMergePassHelper : public FusionHelperBase {
void RecomputeEleGraph( void RecomputeEleGraph(
const GroupPtr& producer, const GroupPtr& producer,
std::unordered_set<GroupPtr, Hasher, Comparator>& fusionable_consumers) { std::unordered_set<GroupPtr, Hasher, Comparator>* fusionable_consumers) {
if (producer->op_pattern_kind != framework::kElementWise) { if (producer->op_pattern_kind != framework::kElementWise) {
SelectConsumerToFuse(producer, fusionable_consumers); SelectConsumerToFuse(producer, fusionable_consumers);
} }
...@@ -679,11 +679,11 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -679,11 +679,11 @@ class FusionMergePassHelper : public FusionHelperBase {
void SelectConsumerToFuse( void SelectConsumerToFuse(
const GroupPtr& producer, const GroupPtr& producer,
std::unordered_set<GroupPtr, Hasher, Comparator>& fusionable_consumers) { std::unordered_set<GroupPtr, Hasher, Comparator>* fusionable_consumers) {
// if is const op // if is const op
if (is_const_group(this, producer)) { if (is_const_group(this, producer)) {
std::unordered_set<GroupPtr, Hasher, Comparator> candidates; std::unordered_set<GroupPtr, Hasher, Comparator> candidates;
for (auto& consumer : fusionable_consumers) { for (auto& consumer : *fusionable_consumers) {
// if can be output node. // if can be output node.
if (is_same_shape(this, producer, consumer)) { if (is_same_shape(this, producer, consumer)) {
candidates.insert(consumer); candidates.insert(consumer);
...@@ -707,10 +707,10 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -707,10 +707,10 @@ class FusionMergePassHelper : public FusionHelperBase {
CHECK_GE(producer->consumer_groups.size(), candidates.size()); CHECK_GE(producer->consumer_groups.size(), candidates.size());
if (producer->consumer_groups.size() == 0 && candidates.size() == 0 && if (producer->consumer_groups.size() == 0 && candidates.size() == 0 &&
output_nodes_set_.count(producer->CollectNodes()[0]) == 0) { output_nodes_set_.count(producer->CollectNodes()[0]) == 0) {
producer->belong_groups.insert(*fusionable_consumers.begin()); producer->belong_groups.insert(*fusionable_consumers->begin());
} }
fusionable_consumers = candidates; *fusionable_consumers = candidates;
return; return;
} }
// 1 to 1 fusion. // 1 to 1 fusion.
...@@ -720,7 +720,7 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -720,7 +720,7 @@ class FusionMergePassHelper : public FusionHelperBase {
if (FLAGS_enhance_vertical_fusion_with_recompute) { if (FLAGS_enhance_vertical_fusion_with_recompute) {
std::vector<GroupPtr> candidates; std::vector<GroupPtr> candidates;
for (auto& consumer : fusionable_consumers) { for (auto& consumer : *fusionable_consumers) {
if (consumer->op_pattern_kind == framework::kElementWise) { if (consumer->op_pattern_kind == framework::kElementWise) {
candidates.push_back(consumer); candidates.push_back(consumer);
continue; continue;
...@@ -764,13 +764,13 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -764,13 +764,13 @@ class FusionMergePassHelper : public FusionHelperBase {
return lhs->op_pattern_kind < rhs->op_pattern_kind; return lhs->op_pattern_kind < rhs->op_pattern_kind;
}); });
fusionable_consumers.clear(); fusionable_consumers->clear();
if (candidates.size()) { if (candidates.size()) {
fusionable_consumers.insert(*candidates.begin()); fusionable_consumers->insert(*candidates.begin());
} }
} else { } else {
std::unordered_set<GroupPtr, Hasher, Comparator> candidates; std::unordered_set<GroupPtr, Hasher, Comparator> candidates;
for (auto& consumer : fusionable_consumers) { for (auto& consumer : *fusionable_consumers) {
if (consumer->op_pattern_kind == framework::kElementWise) { if (consumer->op_pattern_kind == framework::kElementWise) {
candidates.insert(consumer); candidates.insert(consumer);
continue; continue;
...@@ -787,9 +787,9 @@ class FusionMergePassHelper : public FusionHelperBase { ...@@ -787,9 +787,9 @@ class FusionMergePassHelper : public FusionHelperBase {
} }
} }
fusionable_consumers.clear(); fusionable_consumers->clear();
if (candidates.size()) { if (candidates.size()) {
fusionable_consumers.insert(*candidates.begin()); fusionable_consumers->insert(*candidates.begin());
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册