未验证 提交 d38cd6ce 编写于 作者: B BiynXu 提交者: GitHub

[CINN][Fix] change Flatten to Fuse (#56719)

Change FlattenLoops in the elementwise schedule to Fuse
上级 a28e6f63
......@@ -46,15 +46,15 @@ void IRElementwiseSchedule(ir::IRSchedule &ir_sch, // NOLINT
<< ir_sch.GetModule().GetExprs().at(0);
if (target == common::DefaultNVGPUTarget()) {
auto blocks = ir_sch.GetAllBlocks();
ir_sch.FlattenLoops(ir_sch.GetLoops(blocks[0]), true);
std::vector<ir::Expr> loops = ir_sch.GetLoops(blocks[0]);
ir::Expr loop = ir_sch.Fuse(loops);
auto loops = ir_sch.GetLoops(blocks[0]);
auto size = std::accumulate(
output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
if (size <= target.max_num_threads()) {
ir_sch.Bind(loops[0], "threadIdx.x");
ir_sch.Bind(loop, "threadIdx.x");
} else {
auto splited = ir_sch.Split(loops[0], {-1, target.max_num_threads()});
auto splited = ir_sch.Split(loop, {-1, target.max_num_threads()});
ir_sch.Bind(splited[0], "blockIdx.x");
ir_sch.Bind(splited[1], "threadIdx.x");
}
......@@ -74,15 +74,15 @@ void IRInjectiveSchedule(ir::IRSchedule &ir_sch, // NOLINT
<< ir_sch.GetModule().GetExprs().at(0);
if (target == common::DefaultNVGPUTarget()) {
auto blocks = ir_sch.GetAllBlocks();
ir_sch.FlattenLoops(ir_sch.GetLoops(blocks[0]), false);
std::vector<ir::Expr> loops = ir_sch.GetLoops(blocks[0]);
ir::Expr loop = ir_sch.Fuse(loops);
auto loops = ir_sch.GetLoops(blocks[0]);
auto size = std::accumulate(
output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
if (size <= target.max_num_threads()) {
ir_sch.Bind(loops[0], "threadIdx.x");
ir_sch.Bind(loop, "threadIdx.x");
} else {
auto splited = ir_sch.Split(loops[0], {-1, target.max_num_threads()});
auto splited = ir_sch.Split(loop, {-1, target.max_num_threads()});
ir_sch.Bind(splited[0], "blockIdx.x");
ir_sch.Bind(splited[1], "threadIdx.x");
}
......
......@@ -1061,7 +1061,7 @@ set_tests_properties(
PROPERTIES TIMEOUT 120)
set_tests_properties(test_conv_nn_grad PROPERTIES TIMEOUT 120)
set_tests_properties(test_program_prune_backward PROPERTIES TIMEOUT 120)
set_tests_properties(test_group_norm_op PROPERTIES TIMEOUT 300)
set_tests_properties(test_group_norm_op PROPERTIES TIMEOUT 1000)
set_tests_properties(test_imperative_optimizer PROPERTIES TIMEOUT 250)
set_tests_properties(test_imperative_optimizer_v2 PROPERTIES TIMEOUT 250)
set_tests_properties(test_pool2d_op PROPERTIES TIMEOUT 120)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册