未验证 提交 e2009545 编写于 作者: Z zyfncg 提交者: GitHub

[CINN] Adjust the code format in cinn (#55009)

* adjust the code format in cinn

* fix merge conflict
上级 e5725680
......@@ -38,7 +38,7 @@ TEST(Decomposer, relu) {
std::vector<std::string> output_names = {out->id};
std::vector<std::vector<int>> output_shapes = {{20, 10}};
RunAndCheck<float>(
builder, input_names, output_names, output_shapes, relu_cpu, -1, 1);
&builder, input_names, output_names, output_shapes, relu_cpu, -1, 1);
}
TEST(Decomposer, relu_grad) {
......@@ -62,7 +62,7 @@ TEST(Decomposer, relu_grad) {
std::vector<std::string> output_names = {dx->id};
std::vector<std::vector<int>> output_shapes = {{20, 10}};
RunAndCheck<float>(
builder, input_names, output_names, output_shapes, relu_grad_cpu, -1, 1);
&builder, input_names, output_names, output_shapes, relu_grad_cpu, -1, 1);
}
TEST(Decomposer, softmax_decomposer) {
......
......@@ -27,7 +27,7 @@ TEST(Decomposer, elementwise_add_bcast0) {
std::vector<std::string> input_names = {x.id().data(), y.id().data()};
std::vector<std::string> output_names = {out->id};
std::vector<std::vector<int>> output_shapes = {{4, 10, 20, 10}};
RunAndCheckShape<float>(builder, input_names, output_names, output_shapes);
RunAndCheckShape<float>(&builder, input_names, output_names, output_shapes);
}
TEST(Decomposer, elementwise_add_bcase1) {
......@@ -39,7 +39,7 @@ TEST(Decomposer, elementwise_add_bcase1) {
std::vector<std::string> input_names = {x.id().data(), y.id().data()};
std::vector<std::string> output_names = {out->id};
std::vector<std::vector<int>> output_shapes = {{4, 10, 20, 10}};
RunAndCheckShape<float>(builder, input_names, output_names, output_shapes);
RunAndCheckShape<float>(&builder, input_names, output_names, output_shapes);
}
TEST(Decomposer, elementwise_add_grad_bcast0) {
......@@ -52,7 +52,7 @@ TEST(Decomposer, elementwise_add_grad_bcast0) {
std::vector<std::string> input_names = {dout.id().data()};
std::vector<std::string> output_names = {out_grads[0]->id, out_grads[1]->id};
std::vector<std::vector<int>> output_shapes = {{4, 1, 20, 10}, {10, 20}};
RunAndCheckShape<float>(builder, input_names, output_names, output_shapes);
RunAndCheckShape<float>(&builder, input_names, output_names, output_shapes);
}
TEST(Decomposer, elementwise_add_bcast1) {
......@@ -80,7 +80,7 @@ TEST(Decomposer, elementwise_add_bcast1) {
std::vector<std::string> output_names = {out->id};
std::vector<std::vector<int>> output_shapes = {{32, 64, 32, 32}};
RunAndCheck<float>(
builder, input_names, output_names, output_shapes, add_cpu);
&builder, input_names, output_names, output_shapes, add_cpu);
}
TEST(Decomposer, elementwise_add_bcast1_2) {
......@@ -108,7 +108,7 @@ TEST(Decomposer, elementwise_add_bcast1_2) {
std::vector<std::string> output_names = {out->id};
std::vector<std::vector<int>> output_shapes = {{32, 64, 32, 32}};
RunAndCheck<float>(
builder, input_names, output_names, output_shapes, add_cpu);
&builder, input_names, output_names, output_shapes, add_cpu);
}
TEST(Decomposer, elementwise_add_grad_bcast1) {
......@@ -140,7 +140,7 @@ TEST(Decomposer, elementwise_add_grad_bcast1) {
std::vector<std::string> output_names = {out_grads[0]->id, out_grads[1]->id};
std::vector<std::vector<int>> output_shapes = {{32, 64, 32, 32}, {64}};
RunAndCheck<float>(
builder, input_names, output_names, output_shapes, add_grad_cpu);
&builder, input_names, output_names, output_shapes, add_grad_cpu);
}
TEST(Decomposer, elementwise_add_bcast2) {
......@@ -165,7 +165,7 @@ TEST(Decomposer, elementwise_add_bcast2) {
std::vector<std::string> output_names = {out->id};
std::vector<std::vector<int>> output_shapes = {{32, 16}};
RunAndCheck<float>(
builder, input_names, output_names, output_shapes, add_cpu);
&builder, input_names, output_names, output_shapes, add_cpu);
}
TEST(Decomposer, elementwise_add_bcast2_2) {
......@@ -190,7 +190,7 @@ TEST(Decomposer, elementwise_add_bcast2_2) {
std::vector<std::string> output_names = {out->id};
std::vector<std::vector<int>> output_shapes = {{32, 16}};
RunAndCheck<float>(
builder, input_names, output_names, output_shapes, add_cpu);
&builder, input_names, output_names, output_shapes, add_cpu);
}
TEST(Decomposer, elementwise_add_bcast2_3) {
......@@ -217,7 +217,7 @@ TEST(Decomposer, elementwise_add_bcast2_3) {
std::vector<std::string> output_names = {out->id};
std::vector<std::vector<int>> output_shapes = {{32, 16}};
RunAndCheck<int_ty>(
builder, input_names, output_names, output_shapes, add_cpu);
&builder, input_names, output_names, output_shapes, add_cpu);
}
TEST(Decomposer, elementwise_add_grad_bcast2) {
......@@ -244,7 +244,7 @@ TEST(Decomposer, elementwise_add_grad_bcast2) {
std::vector<std::string> output_names = {out_grads[0]->id, out_grads[1]->id};
std::vector<std::vector<int>> output_shapes = {{32, 16}, {1}};
RunAndCheck<float>(
builder, input_names, output_names, output_shapes, add_grad_cpu);
&builder, input_names, output_names, output_shapes, add_grad_cpu);
}
TEST(Decomposer, elementwise_add_same_dims) {
......@@ -268,7 +268,7 @@ TEST(Decomposer, elementwise_add_same_dims) {
std::vector<std::string> output_names = {out->id};
std::vector<std::vector<int>> output_shapes = {{32, 16}};
RunAndCheck<float>(
builder, input_names, output_names, output_shapes, add_cpu);
&builder, input_names, output_names, output_shapes, add_cpu);
}
TEST(Decomposer, elementwise_add_grad_same_dims) {
......@@ -295,7 +295,7 @@ TEST(Decomposer, elementwise_add_grad_same_dims) {
std::vector<std::string> output_names = {out_grads[0]->id, out_grads[1]->id};
std::vector<std::vector<int>> output_shapes = {{32, 16}, {32, 16}};
RunAndCheck<float>(
builder, input_names, output_names, output_shapes, add_grad_cpu);
&builder, input_names, output_names, output_shapes, add_grad_cpu);
}
} // namespace cinn::frontend
......@@ -42,7 +42,7 @@ TEST(Decomposer, sum) {
std::vector<std::string> output_names = {out->id};
std::vector<std::vector<int>> output_shapes = {{32, 16}};
RunAndCheck<float>(
builder, input_names, output_names, output_shapes, sum_cpu);
&builder, input_names, output_names, output_shapes, sum_cpu);
}
} // namespace cinn::frontend
......@@ -193,7 +193,7 @@ void RunDecomposer(Program* prog,
const std::vector<std::string>& fetch_ids = {});
template <typename T>
void RunAndCheckShape(NetBuilder& builder,
void RunAndCheckShape(NetBuilder* builder,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::vector<std::vector<int>>& output_shapes,
......@@ -202,7 +202,7 @@ void RunAndCheckShape(NetBuilder& builder,
T low = 0,
T high = 1,
const std::vector<std::string>& passes = {"Decomposer"}) {
auto prog = builder.Build();
auto prog = builder->Build();
Target target = common::DefaultTarget();
RunDecomposer(&prog, target, passes, output_names);
auto graph = std::make_shared<hlir::framework::Graph>(prog, target);
......@@ -238,7 +238,7 @@ void RunAndCheckShape(NetBuilder& builder,
}
template <typename T>
void RunAndCheck(NetBuilder& builder,
void RunAndCheck(NetBuilder* builder,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::vector<std::vector<int>>& output_shapes,
......
......@@ -42,7 +42,7 @@ TEST(FillConstantRewriter, remove_reshape_single) {
std::vector<std::string> program_passes = {"FillConstantRewriter",
"RemoveIdentity"};
int num_removed_ops =
tester.RunAndCheck(builder, program_passes, input_names, output_names);
tester.RunAndCheck(&builder, program_passes, input_names, output_names);
ASSERT_EQ(num_removed_ops, 2);
}
......@@ -68,7 +68,7 @@ TEST(FillConstantRewriter, remove_reshape_with_fill_constant) {
std::vector<std::string> program_passes = {"FillConstantRewriter",
"RemoveIdentity"};
int num_removed_ops =
tester.RunAndCheck(builder, program_passes, input_names, output_names);
tester.RunAndCheck(&builder, program_passes, input_names, output_names);
ASSERT_EQ(num_removed_ops, 2);
}
......@@ -93,7 +93,7 @@ TEST(FillConstantRewriter, remove_scale_single) {
std::vector<std::string> program_passes = {"FillConstantRewriter",
"RemoveIdentity"};
int num_removed_ops =
tester.RunAndCheck(builder, program_passes, input_names, output_names);
tester.RunAndCheck(&builder, program_passes, input_names, output_names);
ASSERT_EQ(num_removed_ops, 2);
}
......@@ -118,7 +118,7 @@ TEST(FillConstantRewriter, remove_scale_with_fill_constant) {
std::vector<std::string> program_passes = {"FillConstantRewriter",
"RemoveIdentity"};
int num_removed_ops =
tester.RunAndCheck(builder, program_passes, input_names, output_names);
tester.RunAndCheck(&builder, program_passes, input_names, output_names);
ASSERT_EQ(num_removed_ops, 2);
}
......@@ -150,7 +150,7 @@ TEST(FillConstantRewriter, remove_multi_scale_with_fill_constant) {
std::vector<std::string> program_passes = {"FillConstantRewriter",
"RemoveIdentity"};
int num_removed_ops =
tester.RunAndCheck(builder, program_passes, input_names, output_names);
tester.RunAndCheck(&builder, program_passes, input_names, output_names);
ASSERT_EQ(num_removed_ops, 4);
}
......@@ -167,7 +167,7 @@ TEST(FillConstantRewriter, two_fill_constant) {
std::vector<std::string> program_passes = {"FillConstantRewriter",
"RemoveIdentity"};
int num_removed_ops =
tester.RunAndCheck(builder, program_passes, input_names, output_names);
tester.RunAndCheck(&builder, program_passes, input_names, output_names);
ASSERT_EQ(num_removed_ops, 0);
}
......
......@@ -40,7 +40,7 @@ TEST(RemoveIdentity, remove_single) {
std::vector<std::string> program_passes = {"RemoveIdentity",
"DeadCodeEliminate"};
int num_removed_ops =
tester.RunAndCheck(builder, program_passes, input_names, output_names);
tester.RunAndCheck(&builder, program_passes, input_names, output_names);
ASSERT_EQ(num_removed_ops, 3);
}
......@@ -63,7 +63,7 @@ TEST(RemoveIdentity, remove_branch) {
std::vector<std::string> output_names = {reduce_sum_1->id, reduce_sum_2->id};
std::vector<std::string> program_passes = {"RemoveIdentity"};
int num_removed_ops =
tester.RunAndCheck(builder, program_passes, input_names, output_names);
tester.RunAndCheck(&builder, program_passes, input_names, output_names);
ASSERT_EQ(num_removed_ops, 1);
}
......@@ -92,7 +92,7 @@ TEST(RemoveIdentity, remove_multiple) {
std::vector<std::string> output_names = {mul_1->id};
std::vector<std::string> program_passes = {"RemoveIdentity"};
int num_removed_ops =
tester.RunAndCheck(builder, program_passes, input_names, output_names);
tester.RunAndCheck(&builder, program_passes, input_names, output_names);
ASSERT_EQ(num_removed_ops, 3);
}
......@@ -121,7 +121,7 @@ TEST(RemoveIdentity, cannot_remove_fetch) {
std::vector<std::string> output_names = {identity_2->id, mul_1->id};
std::vector<std::string> program_passes = {"RemoveIdentity"};
int num_removed_ops =
tester.RunAndCheck(builder, program_passes, input_names, output_names);
tester.RunAndCheck(&builder, program_passes, input_names, output_names);
ASSERT_EQ(num_removed_ops, 1);
}
......
......@@ -75,11 +75,11 @@ class PassTest {
public:
PassTest() { target_ = common::DefaultTarget(); }
int RunAndCheck(NetBuilder& builder,
int RunAndCheck(NetBuilder* builder,
const std::vector<std::string>& program_passes,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names) {
auto program = builder.Build();
auto program = builder->Build();
CHECK(IsValid(program)) << "The origin program is not valid.";
int origin_program_size = program.size();
LOG(INFO) << "Run origin program";
......
......@@ -176,7 +176,7 @@ class FusionMergePassHelper : public FusionHelperBase {
bool HorizontalFusion(
GroupPtr producer,
std::unordered_set<GroupPtr, Hasher, Comparator>& consumers) {
const std::unordered_set<GroupPtr, Hasher, Comparator>& consumers) {
VLOG(3) << "HorizontalFusion...!";
if (consumers.size() <= 1) {
return false;
......@@ -249,7 +249,7 @@ class FusionMergePassHelper : public FusionHelperBase {
return updated;
}
void HorizontalFuse(GroupList& consumers) {
void HorizontalFuse(const GroupList& consumers) {
VLOG(3) << "HorizontalFuse Groups...";
// create fusion group
auto fused_group = std::make_shared<Graph::Group>();
......@@ -400,8 +400,8 @@ class FusionMergePassHelper : public FusionHelperBase {
}
bool VerticalFusion(
GroupPtr& producer,
std::unordered_set<GroupPtr, Hasher, Comparator>& consumers,
const GroupPtr& producer,
const std::unordered_set<GroupPtr, Hasher, Comparator>& consumers,
bool recompute) {
VLOG(3) << "VerticalFusion, Number of Consumers : " << consumers.size();
auto& relation = fusion_relation_map_[producer->op_pattern_kind];
......@@ -463,14 +463,14 @@ class FusionMergePassHelper : public FusionHelperBase {
if (!recompute) {
return false;
} else {
RecomputeEleGraph(producer, fuse_consumers_unsafe);
RecomputeEleGraph(producer, &fuse_consumers_unsafe);
VerticalFuse(producer, fuse_consumers_unsafe);
return true;
}
}
if (fuse_consumers.size()) {
SelectConsumerToFuse(producer, fuse_consumers);
SelectConsumerToFuse(producer, &fuse_consumers);
}
// if fusionable consumers exist
......@@ -482,9 +482,9 @@ class FusionMergePassHelper : public FusionHelperBase {
return false;
}
void VerticalFuse(
GroupPtr& producer,
std::unordered_set<GroupPtr, Hasher, Comparator>& fusionable_consumers) {
void VerticalFuse(const GroupPtr& producer,
const std::unordered_set<GroupPtr, Hasher, Comparator>&
fusionable_consumers) {
VLOG(3) << "VerticalFuse...!";
GroupList fused_groups;
GroupPtr master_fuesd_group(nullptr);
......@@ -671,7 +671,7 @@ class FusionMergePassHelper : public FusionHelperBase {
void RecomputeEleGraph(
const GroupPtr& producer,
std::unordered_set<GroupPtr, Hasher, Comparator>& fusionable_consumers) {
std::unordered_set<GroupPtr, Hasher, Comparator>* fusionable_consumers) {
if (producer->op_pattern_kind != framework::kElementWise) {
SelectConsumerToFuse(producer, fusionable_consumers);
}
......@@ -679,11 +679,11 @@ class FusionMergePassHelper : public FusionHelperBase {
void SelectConsumerToFuse(
const GroupPtr& producer,
std::unordered_set<GroupPtr, Hasher, Comparator>& fusionable_consumers) {
std::unordered_set<GroupPtr, Hasher, Comparator>* fusionable_consumers) {
// if is const op
if (is_const_group(this, producer)) {
std::unordered_set<GroupPtr, Hasher, Comparator> candidates;
for (auto& consumer : fusionable_consumers) {
for (auto& consumer : *fusionable_consumers) {
// if can be output node.
if (is_same_shape(this, producer, consumer)) {
candidates.insert(consumer);
......@@ -707,10 +707,10 @@ class FusionMergePassHelper : public FusionHelperBase {
CHECK_GE(producer->consumer_groups.size(), candidates.size());
if (producer->consumer_groups.size() == 0 && candidates.size() == 0 &&
output_nodes_set_.count(producer->CollectNodes()[0]) == 0) {
producer->belong_groups.insert(*fusionable_consumers.begin());
producer->belong_groups.insert(*fusionable_consumers->begin());
}
fusionable_consumers = candidates;
*fusionable_consumers = candidates;
return;
}
// 1 to 1 fusion.
......@@ -720,7 +720,7 @@ class FusionMergePassHelper : public FusionHelperBase {
if (FLAGS_enhance_vertical_fusion_with_recompute) {
std::vector<GroupPtr> candidates;
for (auto& consumer : fusionable_consumers) {
for (auto& consumer : *fusionable_consumers) {
if (consumer->op_pattern_kind == framework::kElementWise) {
candidates.push_back(consumer);
continue;
......@@ -764,13 +764,13 @@ class FusionMergePassHelper : public FusionHelperBase {
return lhs->op_pattern_kind < rhs->op_pattern_kind;
});
fusionable_consumers.clear();
fusionable_consumers->clear();
if (candidates.size()) {
fusionable_consumers.insert(*candidates.begin());
fusionable_consumers->insert(*candidates.begin());
}
} else {
std::unordered_set<GroupPtr, Hasher, Comparator> candidates;
for (auto& consumer : fusionable_consumers) {
for (auto& consumer : *fusionable_consumers) {
if (consumer->op_pattern_kind == framework::kElementWise) {
candidates.insert(consumer);
continue;
......@@ -787,9 +787,9 @@ class FusionMergePassHelper : public FusionHelperBase {
}
}
fusionable_consumers.clear();
fusionable_consumers->clear();
if (candidates.size()) {
fusionable_consumers.insert(*candidates.begin());
fusionable_consumers->insert(*candidates.begin());
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册