Unverified commit 70183c4b, authored by Huihuang Zheng, committed by GitHub

Remove Old Schedules in Ops (#55391)

Remove old schedules.
Parent db1f2c42
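Every hunk below applies the same two mechanical edits, now that the `FLAGS_cinn_ir_schedule` flag is removed. On the compute side, the output tensor name is no longer generated with `UniqName(...)` and conditionally overridden; it is always read from the trailing string entry of the packed arguments. A condensed before/after sketch of that pattern (an illustrative composite, not a hunk from this commit; the op name and argument index vary per op):

-    std::string tensor_name = UniqName("MyOp_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2U);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    // the output name is now always supplied through the pack args
+    CHECK_EQ(pack_args.size(), 2U);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();

On the schedule side, each `framework::CINNSchedule` lambda keeps only its IR-schedule branch; a sketch of that half of the pattern follows the first gather_nd section below.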
...@@ -60,37 +60,34 @@ std::shared_ptr<OpStrategy> StrategyForBroadcast(
     const ir::Tensor &B,
     const std::string &output_name,
     const Expr &axis)) {
-  framework::CINNCompute binary_compute([=](lang::Args args,
-                                            lang::RetValue *ret) {
-    CHECK(!args.empty()) << "The input argument of " << op_name
-                         << " compute is empty! Please check.";
-    CINNValuePack pack_args = args[0];
-    CHECK_GE(pack_args.size(), 2U)
-        << "at least 2 input tensors for " << op_name << " compute";
-    std::string tensor_name = UniqName(op_name + "_Out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_GE(pack_args.size(), 3U) << op_name << " 's input is not enough!";
-      CHECK(pack_args[2].is_string());
-      tensor_name = pack_args[2].operator std::string();
-    }
-    Expr A_expr = pack_args[0];
-    Expr B_expr = pack_args[1];
-    CHECK(A_expr.as_tensor());
-    CHECK(B_expr.as_tensor());
-    ir::Tensor A = A_expr.as_tensor_ref();
-    ir::Tensor B = B_expr.as_tensor_ref();
-    Expr axis;
-    bool trans_a;
-    for (auto &iter : attrs.attr_store) {
-      if (iter.first == "axis") {
-        axis = Expr(absl::get<int>(iter.second));
-        break;
-      }
-    }
-    auto out = pe_func(A, B, tensor_name, axis);
-    auto stages = CreateStages({A, B, out});
-    *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}};
-  });
+  framework::CINNCompute binary_compute(
+      [=](lang::Args args, lang::RetValue *ret) {
+        CHECK(!args.empty()) << "The input argument of " << op_name
+                             << " compute is empty! Please check.";
+        CINNValuePack pack_args = args[0];
+        CHECK_GE(pack_args.size(), 2U)
+            << "at least 2 input tensors for " << op_name << " compute";
+        CHECK_GE(pack_args.size(), 3U) << op_name << " 's input is not enough!";
+        CHECK(pack_args[2].is_string());
+        std::string tensor_name = pack_args[2].operator std::string();
+        Expr A_expr = pack_args[0];
+        Expr B_expr = pack_args[1];
+        CHECK(A_expr.as_tensor());
+        CHECK(B_expr.as_tensor());
+        ir::Tensor A = A_expr.as_tensor_ref();
+        ir::Tensor B = B_expr.as_tensor_ref();
+        Expr axis;
+        bool trans_a;
+        for (auto &iter : attrs.attr_store) {
+          if (iter.first == "axis") {
+            axis = Expr(absl::get<int>(iter.second));
+            break;
+          }
+        }
+        auto out = pe_func(A, B, tensor_name, axis);
+        auto stages = CreateStages({A, B, out});
+        *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}};
+      });
   auto strategy = std::make_shared<framework::OpStrategy>();
   strategy->AddImpl(binary_compute,
...@@ -198,12 +195,10 @@ std::shared_ptr<OpStrategy> StrategyForBroadcastTo(
     CINNValuePack pack_args = args[0];
     CHECK(!pack_args.empty())
        << "The input tensors of broadcast_to compute is empty! Please check.";
-    std::string tensor_name = UniqName("broadcast_to_Out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_GE(pack_args.size(), 2U);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_GE(pack_args.size(), 2U);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     Expr A_expr = pack_args[0];
     CHECK(A_expr.as_tensor());
     ir::Tensor A = A_expr.as_tensor_ref();
...@@ -323,12 +318,9 @@ std::shared_ptr<OpStrategy> StrategyForIsClose(
     CINNValuePack pack_args = args[0];
     int input_size = pack_args.size();
-    std::string tensor_name = UniqName("IsClose_output");
-    if (FLAGS_cinn_ir_schedule) {
-      // the last pack argument is the output tensor name
-      tensor_name = pack_args.back().operator std::string();
-      --input_size;
-    }
+    // the last pack argument is the output tensor name
+    std::string tensor_name = pack_args.back().operator std::string();
+    --input_size;
     CHECK_EQ(input_size, 2)
         << "The input number of isclose should be 2, but here "
         << input_size << "! Please check.";
......
...@@ -114,11 +114,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForGatherNd(
     VLOG(3) << "x shape: " << utils::Join(tensor_x->shape, ", ")
             << ", index shape: " << utils::Join(tensor_index->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = UniqName("GatherNd_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 3U);
-      tensor_name = pack_args[2].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 3U);
+    std::string tensor_name = pack_args[2].operator std::string();
     ir::Tensor out = GatherNd(tensor_x, tensor_index, tensor_name);
     std::vector<CINNValue> res;
     stages->InsertLazily(out);
...@@ -131,44 +128,34 @@ std::shared_ptr<framework::OpStrategy> StrategyForGatherNd(
   framework::CINNSchedule gather_nd_schedule([=](lang::Args args,
                                                  lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK(!args.empty()) << "The input argument of gather_nd_schedule is "
-                              "empty! Please check.\n";
-      common::CINNValuePack arg_pack = args[0];
-      std::vector<Expr> vec_ast;
-      for (int i = 0; i < arg_pack.size(); i++) {
-        if (arg_pack[i].is_expr()) {
-          Expr temp = arg_pack[i];
-          vec_ast.emplace_back(temp);
-        }
-      }
-      CHECK(!vec_ast.empty());
-      ir::ModuleExpr mod_expr(vec_ast);
-      ir::IRSchedule ir_sch(mod_expr);
-      ir_sch.MergeExprs();
-      int64_t prod_size = std::accumulate(output_shapes[0].begin(),
-                                          output_shapes[0].end(),
-                                          1,
-                                          std::multiplies<int>());
-      if (prod_size > 1) {
-        if (target.arch == Target::Arch::NVGPU) {
-          pe::IRCudaScheduleInjective(ir_sch, output_shapes.front(), target);
-        } else if (target.arch == Target::Arch::X86) {
-          pe::IRScheduleInjectiveCPU(
-              ir_sch, output_shapes.front(), target, true);
-        }
-      }
-      std::vector<common::CINNValue> res{
-          common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
-      *ret = common::CINNValuePack{res};
-    } else {
-      CHECK(!args.empty()) << "The input argument of gather_nd_schedule is "
-                              "empty! Please check.\n";
-      CINNValuePack arg_pack = args[0];
-      Expr out = arg_pack[0];
-      CHECK(out.as_tensor());
-      *ret = arg_pack;
-    }
+    CHECK(!args.empty()) << "The input argument of gather_nd_schedule is "
+                            "empty! Please check.\n";
+    common::CINNValuePack arg_pack = args[0];
+    std::vector<Expr> vec_ast;
+    for (int i = 0; i < arg_pack.size(); i++) {
+      if (arg_pack[i].is_expr()) {
+        Expr temp = arg_pack[i];
+        vec_ast.emplace_back(temp);
+      }
+    }
+    CHECK(!vec_ast.empty());
+    ir::ModuleExpr mod_expr(vec_ast);
+    ir::IRSchedule ir_sch(mod_expr);
+    ir_sch.MergeExprs();
+    int64_t prod_size = std::accumulate(output_shapes[0].begin(),
+                                        output_shapes[0].end(),
+                                        1,
+                                        std::multiplies<int>());
+    if (prod_size > 1) {
+      if (target.arch == Target::Arch::NVGPU) {
+        pe::IRCudaScheduleInjective(ir_sch, output_shapes.front(), target);
+      } else if (target.arch == Target::Arch::X86) {
+        pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true);
+      }
+    }
+    std::vector<common::CINNValue> res{
+        common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
+    *ret = common::CINNValuePack{res};
   });
   auto strategy = std::make_shared<framework::OpStrategy>();
......
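The schedule-side half of the pattern, condensed from the gather_nd hunk above and the sort/argsort/conv2d/pool hunks below (an illustrative composite; the comment stands in for the per-op `ir::IRSchedule` construction and `pe::` scheduling calls, which differ between ops):

-    if (FLAGS_cinn_ir_schedule) {
-      // build an ir::IRSchedule from the expr args, run the pe:: schedule
-      // pass for this op, and return the scheduled module expr
-    } else {
-      // legacy path: forward the poly::StageMap-based arg pack unchanged
-      CINNValuePack arg_pack = args[0];
-      Expr out = arg_pack[0];
-      CHECK(out.as_tensor());
-      *ret = arg_pack;
-    }
+    // only the IR-schedule body survives, now executed unconditionally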
...@@ -105,12 +105,8 @@ std::shared_ptr<OpStrategy> StrategyForLogicalRightShift(
     ir::Tensor A = A_expr.as_tensor_ref();
     ir::Tensor B = B_expr.as_tensor_ref();
-    std::string tensor_name = UniqName("T_LogicalRightShift_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 3U);
-      tensor_name = pack_args[2].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 3U);
+    std::string tensor_name = pack_args[2].operator std::string();
     auto out = LogicalRightShift(A, B, target, tensor_name);
     auto stages = CreateStages({out});
......
...@@ -106,11 +106,9 @@ std::shared_ptr<framework::OpStrategy> StrategyForLookupTable(
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", B shape: " << utils::Join(tensor_B->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = UniqName("LookupTable_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 3U);
-      tensor_name = pack_args[2].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 3U);
+    std::string tensor_name = pack_args[2].operator std::string();
     ir::Tensor out = LookupTable(tensor_A, tensor_B, padding_idx, tensor_name);
     std::vector<CINNValue> res;
     stages->InsertLazily(out);
......
...@@ -194,12 +194,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForOneHot(
     ir::Tensor on_value = on_value_expr.as_tensor_ref();
     ir::Tensor off_value = off_value_expr.as_tensor_ref();
-    std::string tensor_name = common::UniqName("T_OneHot_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 4U);
-      tensor_name = pack_args[3].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 4U);
+    std::string tensor_name = pack_args[3].operator std::string();
     ir::Tensor out = OneHot(indices,
                             on_value,
......
...@@ -94,13 +94,9 @@ std::shared_ptr<OpStrategy> StrategyForReciprocal(
     CHECK(!pack_args.empty())
         << "at least one input tensor for " << op_name << " compute\n";
-    std::string tensor_name = UniqName("Reciprocal_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     Expr A = pack_args[0];
     CHECK(A.as_tensor());
...@@ -110,10 +106,8 @@ std::shared_ptr<OpStrategy> StrategyForReciprocal(
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2U);
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2U);
+    tensor_name = pack_args[1].operator std::string();
     ir::Tensor out = Reciprocal(tensor_A, tensor_name);
     std::vector<CINNValue> res;
......
...@@ -207,12 +207,9 @@ std::shared_ptr<framework::OpStrategy> StrategyForResize(
     auto tensor_A = A.as_tensor_ref();
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = common::UniqName("T_Resize_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2U);
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2U);
+    std::string tensor_name = pack_args[1].operator std::string();
     ir::Tensor out = Resize(tensor_A, target, out_shape, mode, tensor_name);
......
...@@ -178,12 +178,9 @@ std::shared_ptr<framework::OpStrategy> StrategyForSort(
     auto stages = CreateStages({tensor_A});
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    auto tensor_name = UniqName("Sort_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2U);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2U);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     std::vector<ir::Tensor> out =
         Sort(tensor_A, target, stages, axis, is_ascend, tensor_name);
     stages->InsertLazily(out[0]);
...@@ -195,48 +192,40 @@ std::shared_ptr<framework::OpStrategy> StrategyForSort(
     *ret = CINNValuePack{res};
   });
-  framework::CINNSchedule sort_schedule([=](lang::Args args,
-                                            lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK(!args.empty())
-          << "The input argument of sort_schedule is empty! Please check.\n";
-      common::CINNValuePack arg_pack = args[0];
-      std::vector<Expr> vec_ast;
-      for (int i = 0; i < arg_pack.size(); i++) {
-        if (arg_pack[i].is_expr()) {
-          Expr temp = arg_pack[i];
-          vec_ast.emplace_back(temp);
-        }
-      }
-      CHECK(!vec_ast.empty());
-      ir::ModuleExpr mod_expr(vec_ast);
-      ir::IRSchedule ir_sch(mod_expr);
-      ir_sch.MergeExprs();
-      auto blocks = ir_sch.GetAllBlocks();
-      // TODO(Shixiaowei02): remove external calls, do not use local variables,
-      // because the size will exceed the limit.
-      ir_sch.SetBuffer(blocks[0], "local");
-      ir_sch.SetBuffer(blocks[1], "local");
-      int64_t prod_size = std::accumulate(output_shapes[0].begin(),
-                                          output_shapes[0].end(),
-                                          1,
-                                          std::multiplies<int>());
-      if (prod_size > 1 && target.arch == Target::Arch::X86) {
-        pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true);
-      }
-      std::vector<common::CINNValue> res{
-          common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
-      *ret = common::CINNValuePack{res};
-    } else {
-      CHECK(!args.empty())
-          << "The input argument of sort_schedule is empty! Please check.\n";
-      CINNValuePack arg_pack = args[0];
-      Expr out = arg_pack[0];
-      CHECK(out.as_tensor());
-      *ret = arg_pack;
-    }
-  });
+  framework::CINNSchedule sort_schedule(
+      [=](lang::Args args, lang::RetValue *ret) {
+        CHECK(!args.empty())
+            << "The input argument of sort_schedule is empty! Please check.\n";
+        common::CINNValuePack arg_pack = args[0];
+        std::vector<Expr> vec_ast;
+        for (int i = 0; i < arg_pack.size(); i++) {
+          if (arg_pack[i].is_expr()) {
+            Expr temp = arg_pack[i];
+            vec_ast.emplace_back(temp);
+          }
+        }
+        CHECK(!vec_ast.empty());
+        ir::ModuleExpr mod_expr(vec_ast);
+        ir::IRSchedule ir_sch(mod_expr);
+        ir_sch.MergeExprs();
+        auto blocks = ir_sch.GetAllBlocks();
+        // TODO(Shixiaowei02): remove external calls, do not use local
+        // variables, because the size will exceed the limit.
+        ir_sch.SetBuffer(blocks[0], "local");
+        ir_sch.SetBuffer(blocks[1], "local");
+        int64_t prod_size = std::accumulate(output_shapes[0].begin(),
+                                            output_shapes[0].end(),
+                                            1,
+                                            std::multiplies<int>());
+        if (prod_size > 1 && target.arch == Target::Arch::X86) {
+          pe::IRScheduleInjectiveCPU(
+              ir_sch, output_shapes.front(), target, true);
+        }
+        std::vector<common::CINNValue> res{
+            common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
+        *ret = common::CINNValuePack{res};
+      });
   auto strategy = std::make_shared<framework::OpStrategy>();
   strategy->AddImpl(sort_compute, sort_schedule, "strategy.sort", 1);
...@@ -271,12 +260,9 @@ std::shared_ptr<framework::OpStrategy> StrategyForArgSort(
     auto stages = CreateStages({tensor_A});
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    auto tensor_name = UniqName("ArgSort_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 3U);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 3U);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     auto out = ArgSort(tensor_A, target, stages, axis, is_ascend, tensor_name);
     std::vector<CINNValue> res;
     stages->InsertLazily(out.at(0));
...@@ -291,45 +277,36 @@ std::shared_ptr<framework::OpStrategy> StrategyForArgSort(
   framework::CINNSchedule argsort_schedule([=](lang::Args args,
                                                lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK(!args.empty())
-          << "The input argument of argsort_schedule is empty! Please check.\n";
-      common::CINNValuePack arg_pack = args[0];
-      std::vector<Expr> vec_ast;
-      for (int i = 0; i < arg_pack.size(); i++) {
-        if (arg_pack[i].is_expr()) {
-          Expr temp = arg_pack[i];
-          vec_ast.emplace_back(temp);
-        }
-      }
-      CHECK(!vec_ast.empty());
-      ir::ModuleExpr mod_expr(vec_ast);
-      ir::IRSchedule ir_sch(mod_expr);
-      ir_sch.MergeExprs();
-      auto blocks = ir_sch.GetAllBlocks();
-      // TODO(Shixiaowei02): remove external calls, do not use local variables,
-      // because the size will exceed the limit.
-      // TODO(lanxianghit): There is a bug, setting buffer to "local" here will
-      // cause the var declared twice at CodeGen. ir_sch.SetBuffer(blocks[0],
-      // "local");
-      int64_t prod_size = std::accumulate(output_shapes[0].begin(),
-                                          output_shapes[0].end(),
-                                          1,
-                                          std::multiplies<int>());
-      if (prod_size > 1 && target.arch == Target::Arch::X86) {
-        pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true);
-      }
-      std::vector<common::CINNValue> res{
-          common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
-      *ret = common::CINNValuePack{res};
-    } else {
-      CHECK(!args.empty())
-          << "The input argument of argsort_schedule is empty! Please check.\n";
-      CINNValuePack arg_pack = args[0];
-      Expr out = arg_pack[0];
-      CHECK(out.as_tensor());
-      *ret = arg_pack;
-    }
+    CHECK(!args.empty())
+        << "The input argument of argsort_schedule is empty! Please check.\n";
+    common::CINNValuePack arg_pack = args[0];
+    std::vector<Expr> vec_ast;
+    for (int i = 0; i < arg_pack.size(); i++) {
+      if (arg_pack[i].is_expr()) {
+        Expr temp = arg_pack[i];
+        vec_ast.emplace_back(temp);
+      }
+    }
+    CHECK(!vec_ast.empty());
+    ir::ModuleExpr mod_expr(vec_ast);
+    ir::IRSchedule ir_sch(mod_expr);
+    ir_sch.MergeExprs();
+    auto blocks = ir_sch.GetAllBlocks();
+    // TODO(Shixiaowei02): remove external calls, do not use local variables,
+    // because the size will exceed the limit.
+    // TODO(lanxianghit): There is a bug, setting buffer to "local" here will
+    // cause the var declared twice at CodeGen. ir_sch.SetBuffer(blocks[0],
+    // "local");
+    int64_t prod_size = std::accumulate(output_shapes[0].begin(),
+                                        output_shapes[0].end(),
+                                        1,
+                                        std::multiplies<int>());
+    if (prod_size > 1 && target.arch == Target::Arch::X86) {
+      pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true);
+    }
+    std::vector<common::CINNValue> res{
+        common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
+    *ret = common::CINNValuePack{res};
   });
   auto strategy = std::make_shared<framework::OpStrategy>();
......
...@@ -67,12 +67,9 @@ std::shared_ptr<OpStrategy> StrategyForElementwise(
     CINNValuePack pack_args = args[0];
     CHECK_GE(pack_args.size(), 1U)
         << "1 input tensor for " << op_name << " compute";
-    std::string tensor_name = UniqName(op_name + "_Out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2U);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2U);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     Expr A_expr = pack_args[0];
     CHECK(A_expr.as_tensor());
     ir::Tensor A = A_expr.as_tensor_ref();
...@@ -158,12 +155,9 @@ std::shared_ptr<OpStrategy> StrategyForScale(
     CHECK(A_expr.as_tensor());
     ir::Tensor A = A_expr.as_tensor_ref();
     ir::Tensor out;
-    std::string tensor_name = UniqName("Scale_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     if (bias_after_scale) {
       out = Compute(
...@@ -242,12 +236,9 @@ std::shared_ptr<OpStrategy> StrategyForConstScalar(
     auto scalar = GetScalarExpr(attrs.attr_store.at("value"));
     auto scalar_type = out_type.at(0);
     CINNValuePack pack_args = args[0];
-    std::string tensor_name = UniqName("const_scalar_Out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 1U);
-      CHECK(pack_args[0].is_string());
-      tensor_name = pack_args[0].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 1U);
+    CHECK(pack_args[0].is_string());
+    std::string tensor_name = pack_args[0].operator std::string();
     auto out = lang::Compute(
         {Expr(1)},
...@@ -371,12 +362,9 @@ std::shared_ptr<OpStrategy> StrategyForFillConstant(
     }
     CINNValuePack arg_pack = args[0];
-    std::string tensor_name = UniqName("fill_constant_Out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(arg_pack.size(), 1U);
-      CHECK(arg_pack[0].is_string());
-      tensor_name = arg_pack[0].operator std::string();
-    }
+    CHECK_EQ(arg_pack.size(), 1U);
+    CHECK(arg_pack[0].is_string());
+    std::string tensor_name = arg_pack[0].operator std::string();
     CHECK(!shape.empty()) << "shape attr is empty!";
     auto shape_exprs = ToCinnExprs(shape);
     auto out = lang::Compute(
...@@ -458,12 +446,9 @@ std::shared_ptr<OpStrategy> StrategyForAssignValue(
     const auto &value = attrs.attr_store.at("values");
     CINNValuePack arg_pack = args[0];
-    std::string tensor_name = UniqName("T_assign_value_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(arg_pack.size(), 1U);
-      CHECK(arg_pack[0].is_string());
-      tensor_name = arg_pack[0].operator std::string();
-    }
+    CHECK_EQ(arg_pack.size(), 1U);
+    CHECK(arg_pack[0].is_string());
+    std::string tensor_name = arg_pack[0].operator std::string();
     absl::optional<ir::Tensor> out;
 #define EXPAND_VALUE_TO_TENSOR(TYPE) \
...@@ -649,11 +634,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForSqueeze(
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = UniqName("Squeeze_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2U);
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2U);
+    std::string tensor_name = pack_args[1].operator std::string();
     ir::Tensor out = pe::Squeeze(tensor_A, axes, tensor_name);
     std::vector<CINNValue> res;
...@@ -729,12 +711,9 @@ std::shared_ptr<OpStrategy> StrategyForExpandDims(
     Expr x = input_args[0];
     CHECK(x.as_tensor());
-    std::string tensor_name = UniqName("expand_dims_output");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(input_args.size(), 2U);
-      CHECK(input_args[1].is_string());
-      tensor_name = input_args[1].operator std::string();
-    }
+    CHECK_EQ(input_args.size(), 2U);
+    CHECK(input_args[1].is_string());
+    std::string tensor_name = input_args[1].operator std::string();
     auto out =
         pe::ExpandDims(x.as_tensor_ref(), axes, output_shapes[0], tensor_name);
...@@ -809,12 +788,9 @@ std::shared_ptr<OpStrategy> StrategyForReshape(
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = UniqName("Reshape_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     ir::Tensor out = pe::Reshape(tensor_A, output_shapes[0], tensor_name);
     std::vector<CINNValue> res;
...@@ -901,11 +877,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForCast(
     auto stages = CreateStages({tensor_A});
     VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
             << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-    std::string tensor_name = UniqName("Cast_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2U);
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2U);
+    std::string tensor_name = pack_args[1].operator std::string();
     ir::Tensor out = pe::Cast(tensor_A, out_type[0], tensor_name);
     std::vector<CINNValue> res;
     stages->InsertLazily(out);
...@@ -953,11 +926,8 @@ std::shared_ptr<framework::OpStrategy> StrategyForArange(
         << "The input argument of arange compute is empty! Please check.\n";
     CINNValuePack pack_args = args[0];
-    std::string tensor_name = common::UniqName("T_Arange_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 1U);
-      tensor_name = pack_args[0].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 1U);
+    std::string tensor_name = pack_args[0].operator std::string();
     auto out = pe::Arange(start, stop, step, dtype, tensor_name);
     std::vector<common::CINNValue> res;
......
...@@ -28,8 +28,6 @@
 #include "paddle/cinn/ir/layout.h"
 #include "paddle/cinn/poly/stage.h"
-DECLARE_bool(cinn_ir_schedule);
 namespace cinn {
 namespace hlir {
 namespace op {
...@@ -55,12 +53,9 @@ std::shared_ptr<OpStrategy> StrategyForRelu(
         << "at least one input tensor for relu compute\n";
     Expr A = pack_args[0];
     CHECK(A.as_tensor());
-    std::string tensor_name = UniqName("Relu_output");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     auto out = pe::Relu(A.as_tensor_ref(), 0.0, tensor_name);
     auto stages = CreateStages({out});
     *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}};
...@@ -107,12 +102,9 @@ std::shared_ptr<OpStrategy> StrategyForRelu6(
         << "at least one input tensor for relu6 compute\n";
     Expr A = pack_args[0];
     CHECK(A.as_tensor());
-    std::string tensor_name = UniqName("Relu6_output");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     auto out = pe::Relu6(A.as_tensor_ref(), 0.0, tensor_name);
     auto stages = CreateStages({out});
     *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}};
...@@ -197,12 +189,9 @@ std::shared_ptr<OpStrategy> StrategyForConv2d(
             << utils::Join(A.as_tensor_ref()->shape, ", ");
     VLOG(3) << "weight shape: "
             << utils::Join(B.as_tensor_ref()->shape, ", ");
-    std::string tensor_name = UniqName("Conv2d_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_GE(pack_args.size(), 3);
-      CHECK(pack_args[2].is_string());
-      tensor_name = pack_args[2].operator std::string();
-    }
+    CHECK_GE(pack_args.size(), 3);
+    CHECK(pack_args[2].is_string());
+    std::string tensor_name = pack_args[2].operator std::string();
     if (data_format == "NCHW") {
       // A is input: [N, C, H, W], B is filter: [C_out, C_in/group,
       // filter_h, filter_w]
...@@ -300,222 +289,51 @@ std::shared_ptr<OpStrategy> StrategyForConv2d(
   framework::CINNSchedule conv2d_schedule([=](lang::Args args,
                                               lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK(!args.empty())
-          << "The input argument of conv2d schedule is empty! Please check.\n";
-      CINNValuePack arg_pack = args[0];
-      std::vector<Expr> vec_ast;
-      for (int i = 0; i < arg_pack.size(); i++) {
-        if (arg_pack[i].is_expr()) {
-          Expr temp = arg_pack[i];
-          vec_ast.emplace_back(temp);
-        }
-      }
-      CHECK(!vec_ast.empty());
-      ir::ModuleExpr mod_expr(vec_ast);
-      ir::IRSchedule ir_sch(mod_expr);
-      ir_sch.MergeExprs();
-      if (target.arch == Target::Arch::NVGPU) {
-#ifdef CINN_WITH_CUDNN
-        // If conv_type is backward_filter or backward_data, we built a fake op.
-        // As runtime use cudnn to compute conv2d, this fake op is not to be
-        // called. When cinn support backward_filter/backward_data code gen,
-        // this code is to be removed.
-        if (conv_type != "forward") {
-          CHECK_EQ(vec_ast.size(), 1);
-          pe::IRCudaScheduleInjective(ir_sch, output_shapes.front(), target);
-          std::vector<CINNValue> res{
-              CINNValue(ir_sch.GetModule().GetExprs().at(0))};
-          *ret = CINNValuePack{res};
-          return;
-        }
-#endif
-        int expr_size = vec_ast.size();
-        if (expr_size == 2) {
-          pe::IRCudaScheduleConv(ir_sch, target);
-          VLOG(3) << "After IRCudaScheduleConv, arg_pack[0] is : "
-                  << ir_sch.GetModule().GetExprs().at(0);
-          std::vector<CINNValue> res{
-              CINNValue(ir_sch.GetModule().GetExprs().at(0))};
-          *ret = CINNValuePack{res};
-          return;
-        } else {
-          CINN_NOT_IMPLEMENTED
-        }
-      } else if (target.arch == Target::Arch::X86) {
-        CINN_NOT_IMPLEMENTED
-      }
-      LOG(FATAL) << "This target [" << target << "] is not supported yet.";
-    } else {
-      CHECK(!args.empty())
-          << "The input argument of conv2d schedule is empty! Please check.\n";
-      CINNValuePack arg_pack = args[0];
-      CHECK(arg_pack.size() == 4UL || arg_pack.size() == 3UL ||
-            arg_pack.size() == 6UL || arg_pack.size() == 13UL);
-      poly::StageMap stages = arg_pack.back();
-      if (target.arch == Target::Arch::NVGPU) {
-#ifdef CINN_WITH_CUDNN
-        // If conv_type is backward_filter or backward_data, we built a fake op.
-        // As runtime use cudnn to compute conv2d, this fake op is not to be
-        // called. When cinn support backward_filter/backward_data code gen,
-        // this code is to be removed.
-        if (conv_type != "forward") {
-          Expr out = arg_pack[0];
-          pe::CudaScheduleInjective(
-              stages[out.as_tensor_ref()], output_shapes.front(), target);
-          *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}};
-          return;
-        }
-#endif
-        if (arg_pack.size() == 4UL) {
-          Expr Out = arg_pack[0];
-          Expr input_pad = arg_pack[1];
-          Expr weights = arg_pack[2];
-          ir::Tensor out_t = Out.as_tensor_ref();
-          ir::Tensor input_t = input_pad.as_tensor_ref();
-          ir::Tensor weights_t = weights.as_tensor_ref();
-          CHECK(Out.as_tensor());
-          pe::CudaScheduleConv(stages, input_t, weights_t, out_t, target);
-          arg_pack[0] = Expr(out_t);
-          arg_pack[1] = Expr(input_t);
-          arg_pack[2] = Expr(weights_t);
-          *ret = CINNValuePack{{arg_pack[0], CINNValue(stages)}};
-          return;
-        } else if (arg_pack.size() == 13UL) {
-          Expr wino_weights_dilation = arg_pack[0];
-          Expr wino_input_pad = arg_pack[1];
-          Expr wino_A = arg_pack[2];
-          Expr wino_B = arg_pack[3];
-          Expr wino_G = arg_pack[4];
-          Expr kernel_pack = arg_pack[5];
-          Expr input_tile = arg_pack[6];
-          Expr data_pack = arg_pack[7];
-          Expr bgemm = arg_pack[8];
-          Expr inverse = arg_pack[9];
-          Expr wino_conv = arg_pack[10];
-          ir::Tensor wino_weights_dilation_t =
-              wino_weights_dilation.as_tensor_ref();
-          ir::Tensor wino_input_pad_t = wino_input_pad.as_tensor_ref();
-          ir::Tensor wino_A_t = wino_A.as_tensor_ref();
-          ir::Tensor wino_B_t = wino_B.as_tensor_ref();
-          ir::Tensor wino_G_t = wino_G.as_tensor_ref();
-          ir::Tensor kernel_pack_t = kernel_pack.as_tensor_ref();
-          ir::Tensor input_tile_t = input_tile.as_tensor_ref();
-          ir::Tensor data_pack_t = data_pack.as_tensor_ref();
-          ir::Tensor bgemm_t = bgemm.as_tensor_ref();
-          ir::Tensor inverse_t = inverse.as_tensor_ref();
-          ir::Tensor wino_conv_t = wino_conv.as_tensor_ref();
-          std::vector<ir::Tensor> all_tensors = {wino_weights_dilation_t,
-                                                 wino_input_pad_t,
-                                                 wino_A_t,
-                                                 wino_B_t,
-                                                 wino_G_t,
-                                                 kernel_pack_t,
-                                                 input_tile_t,
-                                                 data_pack_t,
-                                                 bgemm_t,
-                                                 inverse_t,
-                                                 wino_conv_t};
-          hlir::pe::CudaScheduleWinogradConv(stages, all_tensors, target);
-          arg_pack[0] = Expr(all_tensors[0]);
-          arg_pack[1] = Expr(all_tensors[1]);
-          arg_pack[2] = Expr(all_tensors[2]);
-          arg_pack[3] = Expr(all_tensors[3]);
-          arg_pack[4] = Expr(all_tensors[4]);
-          arg_pack[5] = Expr(all_tensors[5]);
-          arg_pack[6] = Expr(all_tensors[6]);
-          arg_pack[7] = Expr(all_tensors[7]);
-          arg_pack[8] = Expr(all_tensors[8]);
-          arg_pack[9] = Expr(all_tensors[9]);
-          arg_pack[10] = Expr(all_tensors[10]);
-          *ret = CINNValuePack{{arg_pack[10],
-                                arg_pack[5],
-                                arg_pack[7],
-                                arg_pack[8],
-                                CINNValue(stages)}};
-          return;
-        }
-      } else if (target.arch == Target::Arch::X86) {
-        if (arg_pack.size() == 6UL) {
-          Expr res = arg_pack[0];
-          Expr packed_out = arg_pack[1];
-          Expr weights_dilation = arg_pack[2];
-          Expr input_pad = arg_pack[3];
-          Expr data = arg_pack[4];
-          CHECK(res.as_tensor());
-          CHECK(packed_out.as_tensor());
-          CHECK(input_pad.as_tensor());
-          CHECK(weights_dilation.as_tensor());
-          CHECK(data.as_tensor());
-          std::vector<Expr> kernel_shape =
-              weights_dilation.as_tensor_ref()->shape;
-          // kernel_h == 1 && kernel_w == 1
-          CHECK_EQ(kernel_shape.size(), 6U)
-              << "kernel_dialtion shape size should be 6";
-          bool is_1x1 =
-              (is_zero(kernel_shape[2] - 1)) && (is_zero(kernel_shape[3] - 1));
-          ir::Tensor packed_out_tensor = packed_out.as_tensor_ref();
-          bool do_padding = (padding[0] == 0 && padding[1] == 0) ? false : true;
-          if (groups == 1) {
-            if (is_1x1) {
-              pe::Conv2d_NCHWc_1X1_Schedule_CPU(
-                  stages,
-                  res.as_tensor_ref(),
-                  packed_out_tensor,
-                  input_pad.as_tensor_ref(),
-                  weights_dilation.as_tensor_ref(),
-                  data.as_tensor_ref(),
-                  target,
-                  key,
-                  do_padding);
-            } else {
-              pe::Conv2d_NCHWc_Schedule_CPU(stages,
-                                            res.as_tensor_ref(),
-                                            packed_out_tensor,
-                                            input_pad.as_tensor_ref(),
-                                            weights_dilation.as_tensor_ref(),
-                                            data.as_tensor_ref(),
-                                            target,
-                                            key,
-                                            do_padding);
-            }
-            if (do_padding) {
-              *ret = CINNValuePack{{CINNValue(res),
-                                    CINNValue(packed_out_tensor),
-                                    arg_pack[2],
-                                    arg_pack[3],
-                                    CINNValue(stages)}};
-            } else {
-              *ret = CINNValuePack{{CINNValue(res),
-                                    CINNValue(packed_out_tensor),
-                                    arg_pack[2],
-                                    CINNValue(stages)}};
-            }
-            return;
-          } else {
-            // todo: opt group_conv schedule
-            VLOG(3) << "use simple group convolution schedule";
-            stages[input_pad.as_tensor_ref()]->ComputeInline();
-            stages[weights_dilation.as_tensor_ref()]->ComputeInline();
-            stages[data.as_tensor_ref()]->ComputeInline();
-            *ret = CINNValuePack{
-                {arg_pack[0], CINNValue(packed_out_tensor), CINNValue(stages)}};
-          }
-          return;
-        } else if (arg_pack.size() == 4UL) {
-          Expr input_pad = arg_pack[1];
-          CHECK(input_pad.as_tensor());
-          stages[input_pad.as_tensor_ref()]->ComputeInline();
-          Expr weights_dilation = arg_pack[2];
-          CHECK(weights_dilation.as_tensor());
-          stages[weights_dilation.as_tensor_ref()]->ComputeInline();
-          *ret = CINNValuePack{{arg_pack[0], CINNValue(stages)}};
-          return;
-        }
-      }
-      *ret = arg_pack;
-    }
+    CHECK(!args.empty())
+        << "The input argument of conv2d schedule is empty! Please check.\n";
+    CINNValuePack arg_pack = args[0];
+    std::vector<Expr> vec_ast;
+    for (int i = 0; i < arg_pack.size(); i++) {
+      if (arg_pack[i].is_expr()) {
+        Expr temp = arg_pack[i];
+        vec_ast.emplace_back(temp);
+      }
+    }
+    CHECK(!vec_ast.empty());
+    ir::ModuleExpr mod_expr(vec_ast);
+    ir::IRSchedule ir_sch(mod_expr);
+    ir_sch.MergeExprs();
+    if (target.arch == Target::Arch::NVGPU) {
+#ifdef CINN_WITH_CUDNN
+      // If conv_type is backward_filter or backward_data, we built a fake op.
+      // As runtime use cudnn to compute conv2d, this fake op is not to be
+      // called. When cinn support backward_filter/backward_data code gen,
+      // this code is to be removed.
+      if (conv_type != "forward") {
+        CHECK_EQ(vec_ast.size(), 1);
+        pe::IRCudaScheduleInjective(ir_sch, output_shapes.front(), target);
+        std::vector<CINNValue> res{
+            CINNValue(ir_sch.GetModule().GetExprs().at(0))};
+        *ret = CINNValuePack{res};
+        return;
+      }
+#endif
+      int expr_size = vec_ast.size();
+      if (expr_size == 2) {
+        pe::IRCudaScheduleConv(ir_sch, target);
+        VLOG(3) << "After IRCudaScheduleConv, arg_pack[0] is : "
+                << ir_sch.GetModule().GetExprs().at(0);
+        std::vector<CINNValue> res{
+            CINNValue(ir_sch.GetModule().GetExprs().at(0))};
+        *ret = CINNValuePack{res};
+        return;
+      } else {
+        CINN_NOT_IMPLEMENTED
+      }
+    } else if (target.arch == Target::Arch::X86) {
+      CINN_NOT_IMPLEMENTED
+    }
+    LOG(FATAL) << "This target [" << target << "] is not supported yet.";
   });
   auto strategy = std::make_shared<framework::OpStrategy>();
...@@ -1007,12 +825,9 @@ std::shared_ptr<OpStrategy> StrategyForDepthwiseConv2d(
     CHECK(data_format == "NCHW" || data_format == "NHWC")
         << "only support NCHW/NHWC data_format.\n";
     std::vector<ir::Tensor> out;
-    std::string tensor_name = UniqName("Depthwise_Conv2d_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_GE(pack_args.size(), 3);
-      CHECK(pack_args[2].is_string());
-      tensor_name = pack_args[2].operator std::string();
-    }
+    CHECK_GE(pack_args.size(), 3);
+    CHECK(pack_args[2].is_string());
+    std::string tensor_name = pack_args[2].operator std::string();
     if (data_format == "NCHW") {
       if (target.arch == Target::Arch::X86) {
         out = pe::Conv2d_NCHW_5D(A.as_tensor_ref(),
...@@ -1060,96 +875,35 @@ std::shared_ptr<OpStrategy> StrategyForDepthwiseConv2d(
     *ret = CINNValuePack{res};
   });
-  framework::CINNSchedule depthwise_conv2d_schedule([=](lang::Args args,
-                                                        lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK(!args.empty()) << "The input argument of InjectiveSchedule is "
-                              "empty! Please check.\n";
-      common::CINNValuePack arg_pack = args[0];
-      std::vector<Expr> vec_ast;
-      std::vector<Expr> vec_tensor;
-      for (int i = 0; i < arg_pack.size(); i++) {
-        if (arg_pack[i].is_expr()) {
-          Expr temp = arg_pack[i];
-          vec_ast.emplace_back(temp);
-        } else if (arg_pack[i].is_tensor()) {
-          Expr temp = arg_pack[i];
-          vec_tensor.emplace_back(temp);
-        }
-      }
-      CHECK(!vec_ast.empty());
-      ir::ModuleExpr mod_expr(vec_ast);
-      ir::IRSchedule ir_sch(mod_expr);
-      ir_sch.MergeExprs();
-      if (target.arch == Target::Arch::NVGPU) {
-        pe::IRCudaScheduleDepthwiseConv(ir_sch, vec_tensor);
-      } else {
-        CINN_NOT_IMPLEMENTED
-      }
-      std::vector<common::CINNValue> res{
-          common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
-      *ret = common::CINNValuePack{res};
-    } else {
-      CHECK(!args.empty()) << "The input argument of depthwise_conv schedule "
-                              "is empty! Please check.\n";
-      CINNValuePack arg_pack = args[0];
-      CHECK(arg_pack.size() == 2UL || arg_pack.size() == 3UL ||
-            arg_pack.size() == 6UL);
-      poly::StageMap stages = arg_pack[arg_pack.size() - 1];
-      Expr Out = arg_pack[0];
-      CHECK(Out.as_tensor());
-      if (arg_pack.size() == 3UL) {
-        Expr input_pad = arg_pack[1];
-        CHECK(input_pad.as_tensor());
-        stages[input_pad.as_tensor_ref()]->ComputeInline();
-      }
-      if (target.arch == Target::Arch::NVGPU) {
-        ir::Tensor output = Out.as_tensor_ref();
-        CHECK(Out.as_tensor());
-        pe::CudaScheduleDepthwiseConv(stages, output, target);
-        arg_pack[0] = Expr(output);
-      } else if (target.arch == Target::Arch::X86) {
-        if (arg_pack.size() == 6UL) {
-          Expr res = arg_pack[0];
-          Expr packed_out = arg_pack[1];
-          Expr weights_dilation = arg_pack[2];
-          Expr input_pad = arg_pack[3];
-          Expr data = arg_pack[4];
-          CHECK(res.as_tensor());
-          CHECK(packed_out.as_tensor());
-          CHECK(input_pad.as_tensor());
-          CHECK(weights_dilation.as_tensor());
-          CHECK(data.as_tensor());
-          ir::Tensor packed_out_tensor = packed_out.as_tensor_ref();
-          bool do_padding = (padding[0] == 0 && padding[1] == 0) ? false : true;
-          pe::Depthwise_Conv2d_NCHWc_Schedule_CPU_Nofuse(
-              stages,
-              res.as_tensor_ref(),
-              packed_out_tensor,
-              input_pad.as_tensor_ref(),
-              weights_dilation.as_tensor_ref(),
-              data.as_tensor_ref(),
-              target,
-              do_padding);
-          if (do_padding) {
-            *ret = CINNValuePack{{CINNValue(res),
-                                  CINNValue(packed_out_tensor),
-                                  arg_pack[2],
-                                  arg_pack[3],
-                                  CINNValue(stages)}};
-          } else {
-            *ret = CINNValuePack{{CINNValue(res),
-                                  CINNValue(packed_out_tensor),
-                                  arg_pack[2],
-                                  CINNValue(stages)}};
-          }
-          return;
-        }
-      }
-      *ret = CINNValuePack{{arg_pack[0], CINNValue(stages)}};
-    }
-  });
+  framework::CINNSchedule depthwise_conv2d_schedule(
+      [=](lang::Args args, lang::RetValue *ret) {
+        CHECK(!args.empty()) << "The input argument of InjectiveSchedule is "
+                                "empty! Please check.\n";
+        common::CINNValuePack arg_pack = args[0];
+        std::vector<Expr> vec_ast;
+        std::vector<Expr> vec_tensor;
+        for (int i = 0; i < arg_pack.size(); i++) {
+          if (arg_pack[i].is_expr()) {
+            Expr temp = arg_pack[i];
+            vec_ast.emplace_back(temp);
+          } else if (arg_pack[i].is_tensor()) {
+            Expr temp = arg_pack[i];
+            vec_tensor.emplace_back(temp);
+          }
+        }
+        CHECK(!vec_ast.empty());
+        ir::ModuleExpr mod_expr(vec_ast);
+        ir::IRSchedule ir_sch(mod_expr);
+        ir_sch.MergeExprs();
+        if (target.arch == Target::Arch::NVGPU) {
+          pe::IRCudaScheduleDepthwiseConv(ir_sch, vec_tensor);
+        } else {
+          CINN_NOT_IMPLEMENTED
+        }
+        std::vector<common::CINNValue> res{
+            common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
+        *ret = common::CINNValuePack{res};
+      });
   auto strategy = std::make_shared<framework::OpStrategy>();
   CHECK(out_type.size())
...@@ -1259,13 +1013,9 @@ std::shared_ptr<OpStrategy> StrategyForBatchNorm(
     Expr Bias = arg_pack[2];
     Expr Mean = arg_pack[3];
     Expr Variance = arg_pack[4];
-    std::string out_name = UniqName("BatchNorm_output");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(arg_pack.size(), 6U);
-      CHECK(arg_pack[5].is_string());
-      std::string str = arg_pack[5];
-      out_name = str;
-    }
+    CHECK_EQ(arg_pack.size(), 6U);
+    CHECK(arg_pack[5].is_string());
+    std::string out_name = arg_pack[5];
     CHECK(A.as_tensor());
     CHECK(Scale.as_tensor());
     CHECK(Bias.as_tensor());
...@@ -1401,12 +1151,9 @@ std::shared_ptr<OpStrategy> StrategyForPool1d(
     CHECK(pool_type == "max" || pool_type == "avg")
         << "pool_type for pool1d should be max or avg.\n";
-    std::string tensor_name = UniqName("Pool1d_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     auto out = pe::Pool1d(A.as_tensor_ref(),
                           kernel_size,
...@@ -1433,66 +1180,43 @@ std::shared_ptr<OpStrategy> StrategyForPool1d(
   framework::CINNSchedule pool1d_schedule([=](lang::Args args,
                                               lang::RetValue *ret) {
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK(!args.empty())
-          << "The input argument of pool1d schedule is empty! Please check.\n";
-      CINNValuePack arg_pack = args[0];
-      std::vector<Expr> vec_ast;
-      std::vector<Expr> vec_tensor;
-      for (int i = 0; i < arg_pack.size(); i++) {
-        if (arg_pack[i].is_expr()) {
-          Expr temp = arg_pack[i];
-          vec_ast.emplace_back(temp);
-        } else if (arg_pack[i].is_tensor()) {
-          Expr temp = arg_pack[i];
-          vec_tensor.emplace_back(temp);
-        }
-      }
-      CHECK(!vec_ast.empty());
-      ir::ModuleExpr mod_expr(vec_ast);
-      ir::IRSchedule ir_sch(mod_expr);
-      ir_sch.MergeExprs();
-      if (arg_pack.size() == 3UL) {
-        CHECK_EQ(vec_tensor.size(), 2);
-        Expr input_pad = vec_tensor[1];
-        CHECK(input_pad.as_tensor());
-        auto block_input_pad = ir_sch.GetBlock(input_pad.as_tensor()->name);
-        ir_sch.ComputeInline(block_input_pad);
-      }
-      if (target.arch == Target::Arch::NVGPU) {
-        CHECK(!vec_tensor.empty());
-        Expr Out = vec_tensor[0];
-        CHECK(Out.as_tensor());
-        auto loops = ir_sch.GetLoops(Out.as_tensor()->name);
-        ir_sch.Split(loops[1], {-1, 2});
-        loops = ir_sch.GetLoops(Out.as_tensor()->name);
-        ir_sch.Bind(loops[0], "blockIdx.x");
-        ir_sch.Bind(loops[1], "threadIdx.x");
-      }
-      std::vector<CINNValue> res{
-          CINNValue(ir_sch.GetModule().GetExprs().at(0))};
-      *ret = CINNValuePack{res};
-    } else {
-      CHECK(!args.empty())
-          << "The input argument of pool1d schedule is empty! Please check.\n";
-      CINNValuePack arg_pack = args[0];
-      CHECK(arg_pack.size() == 2UL || arg_pack.size() == 3UL);
-      Expr Out = arg_pack[0];
-      poly::StageMap stages = arg_pack[arg_pack.size() - 1];
-      if (arg_pack.size() == 3UL) {
-        Expr input_pad = arg_pack[1];
-        CHECK(input_pad.as_tensor());
-        stages[input_pad.as_tensor_ref()]->ComputeInline();
-      }
-      if (target.arch == Target::Arch::NVGPU) {
-        CHECK(Out.as_tensor());
-        stages[Out.as_tensor_ref()]->Split(1, 2);
-        stages[Out.as_tensor_ref()]->Bind(0, "blockIdx.x");
-        stages[Out.as_tensor_ref()]->Bind(1, "threadIdx.x");
-      }
-      *ret = CINNValuePack{{CINNValue(Out), CINNValue(stages)}};
-    }
+    CHECK(!args.empty())
+        << "The input argument of pool1d schedule is empty! Please check.\n";
+    CINNValuePack arg_pack = args[0];
+    std::vector<Expr> vec_ast;
+    std::vector<Expr> vec_tensor;
+    for (int i = 0; i < arg_pack.size(); i++) {
+      if (arg_pack[i].is_expr()) {
+        Expr temp = arg_pack[i];
+        vec_ast.emplace_back(temp);
+      } else if (arg_pack[i].is_tensor()) {
+        Expr temp = arg_pack[i];
+        vec_tensor.emplace_back(temp);
+      }
+    }
+    CHECK(!vec_ast.empty());
+    ir::ModuleExpr mod_expr(vec_ast);
+    ir::IRSchedule ir_sch(mod_expr);
+    ir_sch.MergeExprs();
+    if (arg_pack.size() == 3UL) {
+      CHECK_EQ(vec_tensor.size(), 2);
+      Expr input_pad = vec_tensor[1];
+      CHECK(input_pad.as_tensor());
+      auto block_input_pad = ir_sch.GetBlock(input_pad.as_tensor()->name);
+      ir_sch.ComputeInline(block_input_pad);
+    }
+    if (target.arch == Target::Arch::NVGPU) {
+      CHECK(!vec_tensor.empty());
+      Expr Out = vec_tensor[0];
+      CHECK(Out.as_tensor());
+      auto loops = ir_sch.GetLoops(Out.as_tensor()->name);
+      ir_sch.Split(loops[1], {-1, 2});
+      loops = ir_sch.GetLoops(Out.as_tensor()->name);
+      ir_sch.Bind(loops[0], "blockIdx.x");
+      ir_sch.Bind(loops[1], "threadIdx.x");
+    }
+    std::vector<CINNValue> res{CINNValue(ir_sch.GetModule().GetExprs().at(0))};
+    *ret = CINNValuePack{res};
   });
   auto strategy = std::make_shared<framework::OpStrategy>();
...@@ -1668,12 +1392,9 @@ std::shared_ptr<OpStrategy> StrategyForPool2d(
     CHECK(A.as_tensor());
     ir::Tensor A_tensor = A.as_tensor_ref();
-    std::string tensor_name = UniqName("GlobalPool2d_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     auto out = pe::GlobalPool2d(A_tensor, pool_type, tensor_name);
     CHECK(out.size() == 2U)
...@@ -1687,44 +1408,31 @@ std::shared_ptr<OpStrategy> StrategyForPool2d(
                                                lang::RetValue *ret) {
     CHECK(!args.empty())
         << "The input argument of pool2d schedule is empty! Please check.\n";
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK(!args.empty())
-          << "The input argument of pool1d schedule is empty! Please check.\n";
-      CINNValuePack arg_pack = args[0];
-      std::vector<Expr> vec_ast;
-      std::vector<Expr> vec_tensor;
-      for (int i = 0; i < arg_pack.size(); i++) {
-        if (arg_pack[i].is_expr()) {
-          Expr temp = arg_pack[i];
-          vec_ast.emplace_back(temp);
-        } else if (arg_pack[i].is_tensor()) {
-          Expr temp = arg_pack[i];
-          vec_tensor.emplace_back(temp);
-        }
-      }
-      CHECK(!vec_ast.empty());
-      ir::ModuleExpr mod_expr(vec_ast);
-      ir::IRSchedule ir_sch(mod_expr);
-      ir_sch.MergeExprs();
-      if (target.arch == Target::Arch::NVGPU) {
-        pe::IRGlobalPoolScheduleGPU(ir_sch, target);
-      } else {
-        CINN_NOT_IMPLEMENTED
-      }
-      std::vector<CINNValue> res{
-          CINNValue(ir_sch.GetModule().GetExprs().at(0))};
-      *ret = CINNValuePack{res};
-    } else {
-      CINNValuePack arg_pack = args[0];
-      CHECK_EQ(arg_pack.size(), 3UL);
-      Expr out = arg_pack[0];
-      Expr reduce = arg_pack[1];
-      CHECK(out.as_tensor() && reduce.as_tensor());
-      poly::StageMap stages = arg_pack[arg_pack.size() - 1];
-      pe::GlobalPoolScheduleGPU(
-          stages, {out.as_tensor_ref(), reduce.as_tensor_ref()}, target);
-      *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}};
-    }
+    CHECK(!args.empty())
+        << "The input argument of pool1d schedule is empty! Please check.\n";
+    CINNValuePack arg_pack = args[0];
+    std::vector<Expr> vec_ast;
+    std::vector<Expr> vec_tensor;
+    for (int i = 0; i < arg_pack.size(); i++) {
+      if (arg_pack[i].is_expr()) {
+        Expr temp = arg_pack[i];
+        vec_ast.emplace_back(temp);
+      } else if (arg_pack[i].is_tensor()) {
+        Expr temp = arg_pack[i];
+        vec_tensor.emplace_back(temp);
+      }
+    }
+    CHECK(!vec_ast.empty());
+    ir::ModuleExpr mod_expr(vec_ast);
+    ir::IRSchedule ir_sch(mod_expr);
+    ir_sch.MergeExprs();
+    if (target.arch == Target::Arch::NVGPU) {
+      pe::IRGlobalPoolScheduleGPU(ir_sch, target);
+    } else {
+      CINN_NOT_IMPLEMENTED
+    }
+    std::vector<CINNValue> res{CINNValue(ir_sch.GetModule().GetExprs().at(0))};
+    *ret = CINNValuePack{res};
   });
   framework::CINNCompute pool2d_compute(
...@@ -1736,12 +1444,9 @@ std::shared_ptr<OpStrategy> StrategyForPool2d(
     CHECK(A.as_tensor());
     ir::Tensor A_tensor = A.as_tensor_ref();
-    std::string tensor_name = UniqName("Pool2d_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(pack_args.size(), 2);
-      CHECK(pack_args[1].is_string());
-      tensor_name = pack_args[1].operator std::string();
-    }
+    CHECK_EQ(pack_args.size(), 2);
+    CHECK(pack_args[1].is_string());
+    std::string tensor_name = pack_args[1].operator std::string();
     auto out = pe::Pool2d(A_tensor,
                           kernel_size,
...@@ -1770,63 +1475,41 @@ std::shared_ptr<OpStrategy> StrategyForPool2d( ...@@ -1770,63 +1475,41 @@ std::shared_ptr<OpStrategy> StrategyForPool2d(
framework::CINNSchedule pool2d_schedule([=](lang::Args args, framework::CINNSchedule pool2d_schedule([=](lang::Args args,
lang::RetValue *ret) { lang::RetValue *ret) {
    CHECK(!args.empty())
        << "The input argument of pool2d schedule is empty! Please check.\n";
    CINNValuePack arg_pack = args[0];
    std::vector<Expr> vec_ast;
    std::vector<Expr> vec_tensor;
    for (int i = 0; i < arg_pack.size(); i++) {
      if (arg_pack[i].is_expr()) {
        Expr temp = arg_pack[i];
        vec_ast.emplace_back(temp);
      } else if (arg_pack[i].is_tensor()) {
        Expr temp = arg_pack[i];
        vec_tensor.emplace_back(temp);
      }
    }
    CHECK(!vec_ast.empty());
    ir::ModuleExpr mod_expr(vec_ast);
    ir::IRSchedule ir_sch(mod_expr);
    ir_sch.MergeExprs();
    int arg_pack_size = arg_pack.size();
    // arg_pack_size == 3 case: input, input_pad, output
    // arg_pack_size == 4 case: input, input_pad, output, stage
    if (arg_pack_size == 3UL || arg_pack_size == 4UL) {
      CHECK_EQ(vec_tensor.size(), 2);
      Expr input_pad = vec_tensor[1];
      CHECK(input_pad.as_tensor());
      const std::string &input_pad_name = input_pad.as_tensor()->name;
      VLOG(6) << "ComputeInline on " << input_pad_name;
      auto block_input_pad = ir_sch.GetBlock(input_pad_name);
      ir_sch.ComputeInline(block_input_pad);
    }
    if (target.arch == Target::Arch::NVGPU) {
      pe::IRPoolScheduleGPU(ir_sch, target, arg_pack_size);
    }
    std::vector<CINNValue> res{CINNValue(ir_sch.GetModule().GetExprs().at(0))};
    *ret = CINNValuePack{res};
  });
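The pool2d schedule above no longer materializes the padded input: it looks the producer block up by tensor name and inlines it, so the padding is recomputed at each use inside the pooling loops instead of occupying its own buffer. The same two calls in isolation, where "input_pad" stands for whatever name the compute stage assigned and is used here only for illustration:

// Sketch: inline a producer block into its consumers by name.
auto block_input_pad = ir_sch.GetBlock("input_pad");  // name is illustrative
ir_sch.ComputeInline(block_input_pad);  // no separate buffer for the pad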
  auto strategy = std::make_shared<framework::OpStrategy>();

@@ -1997,12 +1680,9 @@ std::shared_ptr<OpStrategy> StrategyForPool3d(

      CHECK(pool_type == "max" || pool_type == "avg")
          << "pool_type for pool3d should be max or avg.\n";

      CHECK_EQ(pack_args.size(), 2);
      CHECK(pack_args[1].is_string());
      std::string tensor_name = pack_args[1].operator std::string();

      auto out = pe::Pool3d(A.as_tensor_ref(),
                            kernel_size,

@@ -2030,66 +1710,43 @@ std::shared_ptr<OpStrategy> StrategyForPool3d(

  framework::CINNSchedule pool3d_schedule([=](lang::Args args,
                                              lang::RetValue *ret) {
    CHECK(!args.empty())
        << "The input argument of pool3d schedule is empty! Please check.\n";
    CINNValuePack arg_pack = args[0];
    std::vector<Expr> vec_ast;
    std::vector<Expr> vec_tensor;
    for (int i = 0; i < arg_pack.size(); i++) {
      if (arg_pack[i].is_expr()) {
        Expr temp = arg_pack[i];
        vec_ast.emplace_back(temp);
      } else if (arg_pack[i].is_tensor()) {
        Expr temp = arg_pack[i];
        vec_tensor.emplace_back(temp);
      }
    }
    CHECK(!vec_ast.empty());
    ir::ModuleExpr mod_expr(vec_ast);
    ir::IRSchedule ir_sch(mod_expr);
    ir_sch.MergeExprs();
    if (arg_pack.size() == 3UL) {
      CHECK_EQ(vec_tensor.size(), 2);
      Expr input_pad = vec_tensor[1];
      CHECK(input_pad.as_tensor());
      auto block_input_pad = ir_sch.GetBlock(input_pad.as_tensor()->name);
      ir_sch.ComputeInline(block_input_pad);
    }
    if (target.arch == Target::Arch::NVGPU) {
      CHECK(!vec_tensor.empty());
      Expr Out = vec_tensor[0];
      CHECK(Out.as_tensor());
      auto loops = ir_sch.GetLoops(Out.as_tensor()->name);
      ir_sch.Split(loops[1], {-1, 2});
      loops = ir_sch.GetLoops(Out.as_tensor()->name);
      ir_sch.Bind(loops[0], "blockIdx.x");
      ir_sch.Bind(loops[1], "threadIdx.x");
    }
    std::vector<CINNValue> res{CINNValue(ir_sch.GetModule().GetExprs().at(0))};
    *ret = CINNValuePack{res};
  });
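The GPU branch of the pool3d schedule shows the canonical binding recipe: split one loop, then re-query the loop nest (the handles taken before the Split are stale) and bind the outer pair to CUDA axes. The same steps in isolation, with an illustrative tensor name:

// Sketch of the Split + Bind pattern used above.
auto loops = ir_sch.GetLoops("Pool3d_out");  // tensor name is illustrative
ir_sch.Split(loops[1], {-1, 2});    // -1: extent inferred, giving (N/2, 2)
loops = ir_sch.GetLoops("Pool3d_out");  // re-query: Split changed the nest
ir_sch.Bind(loops[0], "blockIdx.x");
ir_sch.Bind(loops[1], "threadIdx.x");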
  auto strategy = std::make_shared<framework::OpStrategy>();

@@ -2236,12 +1893,10 @@ std::shared_ptr<OpStrategy> StrategyForSoftmax(

    }
    std::vector<ir::Tensor> out;

    CHECK_GE(pack_args.size(), 2);
    CHECK(pack_args[pack_args.size() - 1].is_string());
    std::string tensor_name =
        pack_args[pack_args.size() - 1].operator std::string();

#ifdef CINN_WITH_MKLDNN
    if (use_mkldnn) {

@@ -2267,78 +1922,50 @@ std::shared_ptr<OpStrategy> StrategyForSoftmax(

  framework::CINNSchedule softmax_schedule([=](lang::Args args,
                                               lang::RetValue *ret) {
    CHECK(!args.empty())
        << "The input arguments of softmax schedule is empty! Please check.";
    CINNValuePack arg_pack = args[0];
    std::vector<Expr> vec_ast;
    for (int i = 0; i < arg_pack.size(); i++) {
      if (arg_pack[i].is_expr()) {
        Expr temp = arg_pack[i];
        vec_ast.emplace_back(temp);
      }
    }
    CHECK(!vec_ast.empty());
    ir::ModuleExpr mod_expr(vec_ast);
    ir::IRSchedule ir_sch(mod_expr);
    ir_sch.MergeExprs();
    if (target.arch == Target::Arch::NVGPU) {
      if (output_shapes[0].size() > 1) {
        auto all_blocks = ir_sch.GetAllBlocks();
        CHECK_EQ(all_blocks.size(), 3);
        auto loops = ir_sch.GetLoops(all_blocks[2]);
        ir_sch.ComputeAt(all_blocks[1], loops.back());

        if (output_shapes[0][0] != 1) {
          ir_sch.SimpleComputeAt(all_blocks[0], loops[0]);
        }

        loops = ir_sch.GetLoops(all_blocks[2]);
        int loop_index = 1;
        if (output_shapes[0][0] == 1) loop_index--;
        CHECK_GE(loops.size(), loop_index + 1);
        auto splited_loops = ir_sch.Split(loops[loop_index], {-1, 5});

        all_blocks = ir_sch.GetAllBlocks();
        loops = ir_sch.GetLoops(all_blocks[2]);
        ir_sch.Bind(loops[0], "blockIdx.x");
        ir_sch.Bind(loops[1], "threadIdx.x");
      }
      std::vector<CINNValue> res{
          CINNValue(ir_sch.GetModule().GetExprs().at(0))};
      *ret = CINNValuePack{res};
    } else if (target.arch == Target::Arch::X86) {
      pe::IRSoftmaxScheduleCPU(ir_sch, axis);
      std::vector<CINNValue> res{
          CINNValue(ir_sch.GetModule().GetExprs().at(0))};
      *ret = CINNValuePack{res};
    }
  });
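In the softmax schedule above, GetAllBlocks() is expected to return exactly the three stages the softmax compute emits; the middle stage is nested under the output's innermost loop with ComputeAt, and the {-1, 5} split tiles the bound loop with the outer extent inferred. A compressed view of that flow, assuming output rank > 1 and a leading dimension other than 1 (block indices as in the code, everything else condensed):

// Compressed view of the softmax GPU schedule above.
auto blocks = ir_sch.GetAllBlocks();        // {stage0, stage1, out}
auto loops  = ir_sch.GetLoops(blocks[2]);
ir_sch.ComputeAt(blocks[1], loops.back());  // fuse stage1 under out
loops = ir_sch.GetLoops(blocks[2]);
ir_sch.Split(loops[1], {-1, 5});            // tile: (N/5, 5)
loops = ir_sch.GetLoops(ir_sch.GetAllBlocks()[2]);
ir_sch.Bind(loops[0], "blockIdx.x");
ir_sch.Bind(loops[1], "threadIdx.x");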
@@ -2408,12 +2035,9 @@ std::shared_ptr<OpStrategy> StrategyForDropoutInfer(

    CHECK(A_expr.as_tensor());
    ir::Tensor A = A_expr.as_tensor_ref();

    CHECK_EQ(pack_args.size(), 2);
    CHECK(pack_args[1].is_string());
    std::string tensor_name = pack_args[1].operator std::string();

    auto out =
        pe::DropoutInfer(A, dropout_prob, dropout_implementation, tensor_name);

@@ -2479,12 +2103,9 @@ std::shared_ptr<OpStrategy> StrategyForSelect(

    CHECK(true_value.as_tensor());
    CHECK(false_value.as_tensor());

    CHECK_EQ(pack_args.size(), 4U);
    CHECK(pack_args[3].is_string());
    std::string tensor_name = pack_args[3].operator std::string();

    auto out = pe::Select(condition.as_tensor_ref(),
                          true_value.as_tensor_ref(),
...
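All of the compute bodies above now share one calling convention, which the test updates below exercise: the caller appends one string per output tensor to the argument pack, and the compute reads its output name from there instead of generating a UniqName. A sketch of a call site under this convention, with illustrative tensor values and name:

// Sketch: inputs first, then the output tensor's name as a trailing string.
std::string out_name = "Select_output";  // illustrative
common::CINNValuePack pack{{common::CINNValue(condition),
                            common::CINNValue(true_value),
                            common::CINNValue(false_value),
                            common::CINNValue(out_name)}};
auto rets = impl->fcompute(pack);  // the compute reads pack[3] as the name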
@@ -59,37 +59,19 @@ TEST(Operator, Operator_ElementWise_Add_Test0) {

  std::string func_name = "add1";
  Module::Builder builder("module0", target);

  std::string out_name = "C";
  common::CINNValuePack cinn_input =
      common::CINNValuePack{{common::CINNValue(A),
                             common::CINNValue(B),
                             common::CINNValue(out_name)}};
  std::vector<std::string> input_output_names{"A", "B", out_name};

  auto funcs = framework::GetFuncFromImpl(
      impl, cinn_input, inputs, input_output_names, func_name, target);

  for (auto func : funcs) {
    LOG(INFO) << "Test Operator_ElementWise_Add_Test0's Strategy, func is :\n"
              << func;
    builder.AddFunction(func);
  }
@@ -160,37 +142,20 @@ TEST(Operator, Operator_ElementWise_Add_Test1) {

  std::string func_name = "add2";
  Module::Builder builder("module", target);

  std::string out_name = "C";
  common::CINNValuePack cinn_input =
      common::CINNValuePack{{common::CINNValue(A),
                             common::CINNValue(B),
                             common::CINNValue(out_name)}};
  std::vector<std::string> input_output_names{"A", "B", out_name};

  auto funcs = framework::GetFuncFromImpl(
      impl, cinn_input, inputs, input_output_names, func_name, target);

  for (auto func : funcs) {
    builder.AddFunction(func);
    LOG(INFO) << "Test Operator_ElementWise_Add_Test1's Strategy, func is :\n"
              << func;
  }

  backends::CodeGenCUDA_Dev codegen(target);
@@ -225,33 +190,15 @@ TEST(Operator, Operator_BroadcastTo) {

  std::string func_name = "broadcast_to";

  std::string out_name = "C";
  common::CINNValuePack cinn_input = common::CINNValuePack{
      {common::CINNValue(B), common::CINNValue(out_name)}};
  std::vector<std::string> input_output_names{"B", out_name};

  auto funcs = framework::GetFuncFromImpl(
      impl, cinn_input, inputs, input_output_names, func_name, target);

  for (auto func : funcs) {
    LOG(INFO) << "Test Operator_BroadcastTo's Strategy, func is :\n" << func;
  }
}
@@ -260,9 +207,7 @@ common::CINNValuePack GetComputeResult(

    const std::shared_ptr<OpImpl> &impl,
    std::vector<common::CINNValue> &cinn_inputs,  // NOLINT
    const std::string &output_name = "") {
  cinn_inputs.emplace_back(output_name);
  return impl->fcompute(common::CINNValuePack{cinn_inputs});
}
...
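The GetComputeResult helper above centralizes that convention for the tests: it appends the output name to whatever tensor values the caller collected and forwards the pack to fcompute. Typical use, with illustrative values:

// Sketch: the helper appends "C" before calling impl->fcompute.
std::vector<common::CINNValue> cinn_inputs{common::CINNValue(A),
                                           common::CINNValue(B)};
auto rets = GetComputeResult(impl, cinn_inputs, "C");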
@@ -21,8 +21,6 @@

#include "paddle/cinn/hlir/pe/schedule.h"
#include "paddle/cinn/ir/ir_schedule.h"

namespace cinn {
namespace hlir {

@@ -31,44 +29,24 @@ CINNSchedule GetElementwiseScheduleFunc(

    const Target& target,
    bool vectorizable) {
  return CINNSchedule([=](lang::Args args, lang::RetValue* ret) {
    CHECK(!args.empty()) << "The input argument of ElementwiseSchedule is "
                            "empty! Please check.\n";
    common::CINNValuePack arg_pack = args[0];
    std::vector<Expr> vec_ast;
    for (int i = 0; i < arg_pack.size(); i++) {
      if (arg_pack[i].is_expr()) {
        Expr temp = arg_pack[i];
        vec_ast.emplace_back(temp);
      }
    }
    CHECK(!vec_ast.empty());
    ir::ModuleExpr mod_expr(vec_ast);
    ir::IRSchedule ir_sch(mod_expr);
    ir_sch.MergeExprs();
    pe::IRElementwiseSchedule(ir_sch, output_shapes.front(), target);
    std::vector<common::CINNValue> res{
        common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
    *ret = common::CINNValuePack{res};
  });
}
@@ -77,50 +55,30 @@ CINNSchedule GetInjectiveScheduleFunc(

    const Target& target,
    bool vectorizable) {
  return CINNSchedule([=](lang::Args args, lang::RetValue* ret) {
    CHECK(!args.empty()) << "The input argument of InjectiveSchedule is "
                            "empty! Please check.\n";
    common::CINNValuePack arg_pack = args[0];
    std::vector<Expr> vec_ast;
    for (int i = 0; i < arg_pack.size(); i++) {
      if (arg_pack[i].is_expr()) {
        Expr temp = arg_pack[i];
        vec_ast.emplace_back(temp);
      }
    }
    CHECK(!vec_ast.empty());
    ir::ModuleExpr mod_expr(vec_ast);
    ir::IRSchedule ir_sch(mod_expr);
    ir_sch.MergeExprs();
    pe::IRInjectiveSchedule(ir_sch, output_shapes.front(), target);
    /*if (target.arch == Target::Arch::NVGPU) {
      pe::IRInjectiveSchedule(ir_sch, output_shapes.front(), target);
    } else if (target.arch == Target::Arch::X86) {
      pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target,
    vectorizable);
    }*/
    std::vector<common::CINNValue> res{
        common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
    *ret = common::CINNValuePack{res};
  });
}
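Both helpers now hand back a single scheduled module expr, so an op strategy wires them in unchanged. A sketch of that wiring under the usual StrategyFor* boilerplate, assuming output_shapes is the helper's leading parameter (its use of output_shapes.front() suggests so); the compute, impl name, and priority are illustrative:

// Sketch: plugging the shared schedule helper into an op strategy.
auto strategy = std::make_shared<framework::OpStrategy>();
strategy->AddImpl(unary_compute,
                  GetInjectiveScheduleFunc(output_shapes, target, true),
                  "strategy.unary.x86",  // illustrative impl name
                  1);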
...

@@ -29,8 +29,6 @@

#include "paddle/cinn/ir/ir_schedule.h"
#include "paddle/cinn/optim/ir_simplify.h"

namespace cinn {
namespace hlir {
namespace op {

@@ -115,16 +113,10 @@ std::shared_ptr<OpStrategy> StrategyForReduce(

        CHECK(!args.empty()) << "The input argument of " << op_name
                             << " compute is empty! Please check.";
        CINNValuePack arg_packs = args[0];
        CHECK_EQ(arg_packs.size(), 2U)
            << "There should be 2 input args for " << op_name << " compute";
        CHECK(arg_packs[1].is_string());
        std::string tensor_name = arg_packs[1].operator std::string();

        Expr x_expr = arg_packs[0];
        CHECK(x_expr.as_tensor());
        ir::Tensor x = x_expr.as_tensor_ref();
@@ -175,206 +167,137 @@ std::shared_ptr<OpStrategy> StrategyForReduce(

                                             lang::RetValue *ret) {
    CHECK(!args.empty()) << "The input argument of " << op_name
                         << " schedule is empty! Please check.";
    CINNValuePack arg_pack = args[0];
    CHECK_GE(arg_pack.size(), 2UL);
    CHECK_LE(arg_pack.size(), 8UL);
    std::vector<Expr> vec_ast;
    std::vector<Expr> vec_tensor;
    for (int i = 0; i < arg_pack.size(); i++) {
      if (arg_pack[i].is_expr()) {
        Expr temp = arg_pack[i];
        // TODO(zhhsplendid): old reducetion schedule assumes all length-1
        // for loops are simplified, but it is not after we add length-1
        // back. Reduction schedule is complex and we haven't changed it to
        // support the length-1 for loop yet. So we simplify here. The todo
        // is that remove SimplifyForLoops below and change reduction schedule
        optim::SimplifyForLoops(&temp);
        vec_ast.emplace_back(temp);
      } else if (arg_pack[i].is_tensor()) {
        Expr temp = arg_pack[i];
        vec_tensor.emplace_back(temp);
      }
    }
    CHECK(!vec_ast.empty());
    ir::ModuleExpr mod_expr(vec_ast);
    ir::IRSchedule ir_sch(mod_expr);
    ir_sch.MergeExprs();
    if (target.arch == Target::Arch::NVGPU) {
      if (!WithoutLastDimInReduce(inputs[0]->shape, reduce_axes)) {
        if (arg_pack.size() == 4) {
          CHECK_EQ(vec_tensor.size(), 2);
          Expr out = vec_tensor[0];
          Expr tmp_out = vec_tensor[1];

          VLOG(3) << "Do IRCudaScheduleBlockReduceInternal Schedule!";
          pe::IRCudaScheduleBlockReduceInternal(
              ir_sch, tmp_out.as_tensor_ref(), out.as_tensor_ref(), target);

          std::vector<CINNValue> res{
              CINNValue(ir_sch.GetModule().GetExprs().at(0))};
          *ret = CINNValuePack{res};
        } else if (arg_pack.size() == 6) {
          CHECK_EQ(vec_tensor.size(), 3);
          Expr out = vec_tensor[0];
          Expr tmp_out = vec_tensor[1];
          Expr reduce_tmp_out = vec_tensor[2];

          VLOG(3) << "Do IRCudaScheduleBlockReduce Schedule!";
          pe::IRCudaScheduleBlockReduce(ir_sch,
                                        reduce_tmp_out.as_tensor_ref(),
                                        tmp_out.as_tensor_ref(),
                                        out.as_tensor_ref(),
                                        target);

          std::vector<CINNValue> res{
              CINNValue(ir_sch.GetModule().GetExprs().at(0))};
          *ret = CINNValuePack{res};
        } else if (arg_pack.size() == 7) {
          CHECK_EQ(vec_tensor.size(), 4);
          Expr out = vec_tensor[0];
          Expr tmp_out = vec_tensor[1];
          Expr reduce_tmp_out = vec_tensor[2];
          Expr reshape = vec_tensor[3];

          VLOG(3) << "Do IRCudaTwoStepReduceSchedule Schedule!";
          pe::IRCudaTwoStepReduceSchedule(ir_sch,
                                          reshape.as_tensor_ref(),
                                          reduce_tmp_out.as_tensor_ref(),
                                          tmp_out.as_tensor_ref(),
                                          out.as_tensor_ref(),
                                          common::DefaultNVGPUTarget());

          std::vector<CINNValue> res{
              CINNValue(ir_sch.GetModule().GetExprs().at(0))};
          *ret = CINNValuePack{res};
        } else if (arg_pack.size() == 5) {
          CHECK_EQ(vec_tensor.size(), 3);
          Expr out = vec_tensor[0];
          Expr tmp_out = vec_tensor[1];
          Expr reduce_tmp_out = vec_tensor[2];

          VLOG(3) << "Do IRCudaScheduleBlockReduce Schedule!";
          pe::IRCudaScheduleBlockReduce(ir_sch,
                                        reduce_tmp_out.as_tensor_ref(),
                                        tmp_out.as_tensor_ref(),
                                        out.as_tensor_ref(),
                                        common::DefaultNVGPUTarget());

          std::vector<CINNValue> res{
              CINNValue(ir_sch.GetModule().GetExprs().at(0))};
          *ret = CINNValuePack{res};
        } else {
          LOG(FATAL) << "Unkown Reduce Type!";
        }
      } else {
        if (arg_pack.size() == 2) {
          CHECK_EQ(vec_tensor.size(), 1);
          Expr reduce_out = vec_tensor[0];

          VLOG(3) << "Do IRCudaScheduleReduce Schedule!";
          pe::IRCudaScheduleReduce(
              ir_sch,
              reduce_out.as_tensor_ref(),
              inputs[0]->shape.size() - reduce_axes.back() - 1,
              target);

          std::vector<CINNValue> res{
              CINNValue(ir_sch.GetModule().GetExprs().at(0))};
          *ret = CINNValuePack{res};
        } else if (arg_pack.size() == 6) {
          CHECK_EQ(vec_tensor.size(), 3);
          Expr reduce_out = vec_tensor[0];
          Expr reduce_internal = vec_tensor[1];
          Expr reduce_reshape = vec_tensor[2];

          VLOG(3) << "Do IRCudaScheduleBlockShuffleReduce Schedule!";
          pe::IRCudaScheduleBlockShuffleReduce(ir_sch,
                                               reduce_reshape.as_tensor_ref(),
                                               reduce_internal.as_tensor_ref(),
                                               reduce_out.as_tensor_ref(),
                                               target);

          std::vector<CINNValue> res{
              CINNValue(ir_sch.GetModule().GetExprs().at(0))};
          *ret = CINNValuePack{res};
        } else {
          LOG(FATAL) << "Unkown Reduce Type!";
        }
      }
    } else {
      std::vector<CINNValue> res{
          CINNValue(ir_sch.GetModule().GetExprs().at(0))};
      *ret = CINNValuePack{res};
    }
  });
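The reduce schedule dispatches purely on how many entries the compute left in arg_pack, which encodes how far the reduction was decomposed. As a reading aid, the mapping implemented in the NVGPU branch above:

// arg_pack sizes vs. tensors and the schedule chosen above (NVGPU only):
//   reducing over the last dimension:
//     4 -> {out, tmp_out}                          -> IRCudaScheduleBlockReduceInternal
//     5 -> {out, tmp_out, reduce_tmp_out}          -> IRCudaScheduleBlockReduce
//     6 -> {out, tmp_out, reduce_tmp_out}          -> IRCudaScheduleBlockReduce
//     7 -> {out, tmp_out, reduce_tmp_out, reshape} -> IRCudaTwoStepReduceSchedule
//   not reducing over the last dimension:
//     2 -> {reduce_out}                            -> IRCudaScheduleReduce
//     6 -> {reduce_out, reduce_internal, reduce_reshape}
//                                                  -> IRCudaScheduleBlockShuffleReduce
//   anything else -> LOG(FATAL)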
...

@@ -73,12 +73,9 @@ std::shared_ptr<OpStrategy> StrategyForMatMul(

    CHECK(A.as_tensor());
    CHECK(B.as_tensor());

    CHECK_GE(pack_args.size(), 3);
    CHECK(pack_args[2].is_string());
    std::string tensor_name = pack_args[2].operator std::string();

    auto tensor_A = A.as_tensor_ref();
    auto tensor_B = B.as_tensor_ref();

@@ -130,32 +127,9 @@ std::shared_ptr<OpStrategy> StrategyForMatMul(

    CHECK(!args.empty())
        << "The input argument of matmul schedule is empty! Please check.\n";
    CINNValuePack arg_pack = args[0];
    std::vector<CINNValue> results =
        pe::IRCudaScheduleMatMul(arg_pack, output_shape, target);
    *ret = CINNValuePack({results});
  });

  auto strategy = std::make_shared<framework::OpStrategy>();
@@ -262,16 +236,10 @@ std::shared_ptr<OpStrategy> StrategyForSplit(

    ir::Tensor A = A_expr.as_tensor_ref();
    std::vector<std::string> tensor_names;

    CHECK_EQ(pack_args.size(), output_shapes.size() + 1);
    for (int idx = 1; idx < pack_args.size(); ++idx) {
      CHECK(pack_args[idx].is_string());
      tensor_names.push_back(pack_args[idx].operator std::string());
    }

    auto out = pe::Split(A, axis, output_shapes, tensor_names);

@@ -285,38 +253,27 @@ std::shared_ptr<OpStrategy> StrategyForSplit(

    *ret = CINNValuePack{res};
  });
  framework::CINNSchedule split_schedule(
      [=](lang::Args args, lang::RetValue *ret) {
        CHECK(!args.empty())
            << "The input argument of split schedule is empty! Please check.";
        CINNValuePack arg_pack = args[0];
        std::vector<Expr> vec_ast;
        for (int i = 0; i < arg_pack.size(); i++) {
          if (arg_pack[i].is_expr()) {
            Expr temp = arg_pack[i];
            vec_ast.emplace_back(temp);
          }
        }
        CHECK(!vec_ast.empty());
        ir::ModuleExpr mod_expr(vec_ast);
        ir::IRSchedule ir_sch(mod_expr);
        ir_sch.MergeExprs();
        pe::IRCudaSplitSchedule(ir_sch, output_shapes, axis, target);
        std::vector<CINNValue> res{
            CINNValue(ir_sch.GetModule().GetExprs().at(0))};
        *ret = CINNValuePack{res};
      });

  auto strategy = std::make_shared<framework::OpStrategy>();
  strategy->AddImpl(split_compute, split_schedule, "strategy.split.x86", 1);
@@ -468,8 +425,7 @@ std::shared_ptr<OpStrategy> StrategyForConcat(

    CHECK(!out_type.empty())
        << "Output type of Concat is empty! Please check.\n";
    CINNValuePack pack_args = args[0];
    int input_size = pack_args.size() - 1;
    CHECK_GE(input_size, 1UL)
        << "at least 2 input tensors for Concat compute\n";
    CHECK(!output_shapes.empty());

@@ -485,11 +441,8 @@ std::shared_ptr<OpStrategy> StrategyForConcat(

      input_tensors.push_back(tensor.as_tensor_ref());
    }

    CHECK(pack_args[input_size].is_string());
    std::string tensor_name = pack_args[input_size].operator std::string();

    auto stages = CreateStages(input_tensors);
    auto out = pe::Concat(input_tensors, axis, tensor_name);
@@ -612,11 +565,8 @@ std::shared_ptr<OpStrategy> StrategyForMul(

    auto new_B = B_tensor->Reshape(new_shape_B_e, stages);
    std::vector<ir::Tensor> out;

    CHECK(pack_args.back().is_string());
    std::string tensor_name = pack_args.back().operator std::string();

    if (target.arch == Target::Arch::X86) {
#ifdef CINN_WITH_MKL_CBLAS
@@ -647,32 +597,9 @@ std::shared_ptr<OpStrategy> StrategyForMul(

    CHECK(!args.empty())
        << "The input argument of matmul schedule is empty! Please check.\n";
    CINNValuePack arg_pack = args[0];
    std::vector<CINNValue> results =
        pe::IRCudaScheduleMatMul(arg_pack, output_shape, target);
    *ret = CINNValuePack({results});
  });

  auto strategy = std::make_shared<framework::OpStrategy>();
@@ -780,12 +707,9 @@ std::shared_ptr<OpStrategy> StrategyForCublasGemm(

    // dummy gemm computation, which will be replaced by
    // cinn_gpu_cublas_gemm in the GemmRewriter pass.
    CHECK_EQ(input_args.size(), 4);
    CHECK(input_args[3].is_string());
    std::string tensor_name = input_args[3].operator std::string();

    auto out = pe::Identity(bias_tensor, tensor_name).front();
    auto stages = CreateStages(
        {lhs.as_tensor_ref(), rhs.as_tensor_ref(), bias_tensor});

@@ -849,12 +773,9 @@ std::shared_ptr<OpStrategy> StrategyForLayoutTransform(

    Expr A = input_args[0];
    CHECK(A.as_tensor());

    CHECK_EQ(input_args.size(), 2);
    CHECK(input_args[1].is_string());
    std::string tensor_name = input_args[1].operator std::string();

    auto out = pe::LayoutTransform(
        A.as_tensor_ref(), src_layout, dst_layout, tensor_name);

@@ -865,53 +786,31 @@ std::shared_ptr<OpStrategy> StrategyForLayoutTransform(

    *ret = CINNValuePack{res};
  });
  framework::CINNSchedule layout_transform_schedule([=](lang::Args args,
                                                        lang::RetValue *ret) {
    CHECK(!args.empty()) << "The input argument of CublasGemm schedule "
                            "is empty! Please check.";
    CINNValuePack arg_pack = args[0];
    std::vector<Expr> vec_ast;
    for (int i = 0; i < arg_pack.size(); i++) {
      if (arg_pack[i].is_expr()) {
        Expr temp = arg_pack[i];
        vec_ast.emplace_back(temp);
      }
    }
    CHECK(!vec_ast.empty());
    ir::ModuleExpr mod_expr(vec_ast);
    ir::IRSchedule ir_sch(mod_expr);
    ir_sch.MergeExprs();

    if (target.arch == Target::Arch::X86) {
      pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target);
    } else {
      CINN_NOT_IMPLEMENTED
    }
    std::vector<CINNValue> res{CINNValue(ir_sch.GetModule().GetExprs().at(0))};
    *ret = CINNValuePack{res};
  });

  auto strategy = std::make_shared<framework::OpStrategy>();
  CHECK(out_type.size())
@@ -996,12 +895,9 @@ std::shared_ptr<OpStrategy> StrategyForReverse(

    Expr A = input_args[0];
    CHECK(A.as_tensor());

    CHECK_EQ(input_args.size(), 2);
    CHECK(input_args[1].is_string());
    std::string tensor_name = input_args[1].operator std::string();

    auto out = pe::Reverse(A.as_tensor_ref(), axis, tensor_name);
    auto stages = CreateStages({A.as_tensor_ref(), out});
@@ -1113,12 +1009,9 @@ std::shared_ptr<OpStrategy> StrategyForTranspose(

        << "at least one input tensor for transpose compute\n";
    Expr A = input_args[0];
    CHECK(A.as_tensor());

    CHECK_EQ(input_args.size(), 2);
    CHECK(input_args[1].is_string());
    std::string tensor_name = input_args[1].operator std::string();

    auto out = pe::Transpose(A.as_tensor_ref(), axis, tensor_name);
    auto stages = CreateStages({out});
@@ -1236,12 +1129,9 @@ std::shared_ptr<OpStrategy> StrategyForGather(

    Expr index = input_args[1];
    CHECK(index.as_tensor());

    CHECK_EQ(input_args.size(), 3U);
    CHECK(input_args[2].is_string());
    std::string tensor_name = input_args[2].operator std::string();

    auto out = pe::Gather(x.as_tensor_ref(),
                          index.as_tensor_ref(),
@@ -1335,12 +1225,9 @@ std::shared_ptr<OpStrategy> StrategyForScatterAssign(

    auto stages = CreateStages({tensor_input, tensor_updates, tensor_index});

    CHECK_EQ(arg_pack.size(), 4U);
    CHECK(arg_pack[3].is_string());
    std::string tensor_name = arg_pack[3].operator std::string();

    auto out = pe::ScatterAssign(
        tensor_input, tensor_updates, tensor_index, target, axis, tensor_name);
@@ -1462,12 +1349,9 @@ std::shared_ptr<OpStrategy> StrategyForScatterAdd(

    auto stages = CreateStages({tensor_input, tensor_updates, tensor_index});

    CHECK_EQ(arg_pack.size(), 4U);
    CHECK(arg_pack[3].is_string());
    std::string tensor_name = arg_pack[3].operator std::string();

    auto out = pe::ScatterAdd(
        tensor_input, tensor_updates, tensor_index, target, axis, tensor_name);
@@ -1617,12 +1501,9 @@ std::shared_ptr<OpStrategy> StrategyForSlice(

    CHECK(A_expr.as_tensor());
    ir::Tensor A = A_expr.as_tensor_ref();

    CHECK_EQ(arg_pack.size(), 2U);
    CHECK(arg_pack[1].is_string());
    std::string tensor_name = arg_pack[1].operator std::string();

    auto out = pe::Slice(
        A, starts, axes, strides, decrease_axis, output_shape, tensor_name);
@@ -1854,12 +1735,9 @@ std::shared_ptr<OpStrategy> StrategyForSliceAssign(

    Expr assign = arg_pack[1];
    CHECK(assign.as_tensor());

    CHECK_EQ(arg_pack.size(), 3U);
    CHECK(arg_pack[2].is_string());
    std::string tensor_name = arg_pack[2].operator std::string();

    auto out = pe::SliceAssign(input.as_tensor_ref(),
                               assign.as_tensor_ref(),
...
@@ -86,40 +86,18 @@ TEST(SliceAssign, SliceAssign_Op) {

  std::string func_name = "slice_assign";

  std::string out_name = "output";
  common::CINNValuePack cinn_input =
      common::CINNValuePack{{common::CINNValue(input.tensor()),
                             common::CINNValue(assign.tensor()),
                             common::CINNValue(out_name)}};
  std::vector<std::string> input_output_names{"input", "assign", out_name};

  auto funcs = framework::GetFuncFromImpl(
      impl, cinn_input, inputs, input_output_names, func_name, target);

  for (auto func : funcs) {
    LOG(INFO) << "Test Operator_BroadcastTo's Strategy, func is :\n" << func;
  }
}
...