diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc index a4e54c0731987ee43aafca3ab532f7012c688d18..609550e4fe4d777fbb5dd912e2de0acd51f684da 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc @@ -71,7 +71,6 @@ TEST(AutoInline, SingleLoopInline) { nullptr, target, true); - VLOG(6) << "Expr after lowering:"; VLOG(6) << funcs[0]->body; @@ -170,7 +169,9 @@ TEST(AutoInline, AddReluInline) { EXPECT_EQ(graph->fusion_groups.size(), 1UL); std::vector funcs = - op_lowerer->LowerWithoutSchedule(graph->fusion_groups[0]); + op_lowerer->Lower(graph->fusion_groups[0], + /*apply_op_schedule = */ false, + /*apply_group_schedule=*/false); VLOG(6) << "Expr before auto inline: " << funcs[0]->body; diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc index 10c0ccc73489b263f7ef99cea4e6ae75beb47192..c41d9171c835f5a55731ccfca704b6a81ab20f79 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc @@ -388,9 +388,9 @@ TEST_F(TestMultiLevelTiling, ReduceSum) { TEST_F(TestMultiLevelTiling, Pool2d) { default_input_names = {"input"}; - default_output_names = {"var_0"}; - std::vector input_shape{2, 8, 16, 16}; - std::vector output_shape{2, 8, 8, 8}; + default_output_names = {"var_0", "pad_temp_0"}; + std::vector> input_shapes{{2, 8, 16, 16}}; + std::vector> output_shapes{{2, 8, 8, 8}, {2, 8, 18, 18}}; std::string pooling_type = "max"; std::vector ksize{3, 3}; std::vector strides{2, 2}; @@ -402,7 +402,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) { bool adaptive = false; std::string padding_algorithm = "EXPLICIT"; frontend::Program pool2d_program = tests::OpBuilder("pool2d").Build( - {{"input", input_shape}}, + {{"input", input_shapes[0]}}, {{"pool_type", pooling_type}, {"kernel_size", ksize}, {"stride_size", strides}, @@ -440,85 +440,82 @@ TEST_F(TestMultiLevelTiling, Pool2d) { { ScheduleBlock(root) { - serial for (i, 0, 2) { - serial for (j, 0, 8) + serial for (i, 0, 2) { - serial for (k, 0, 18) + serial for (j, 0, 8) { - serial for (a, 0, 18) + serial for (k, 0, 18) { - ScheduleBlock(pad_temp_0) + serial for (a, 0, 18) { - i0, i1, i2, i3 = axis.bind(i, j, k, a) - pad_temp_0[i, j, k, a] = select(((a < 17) and ((a >= 1) and ((k < 17) and (k >= 1)))), input[i, j, (-1 + k), (-1 + a)], -3.40282347e+38f) + ScheduleBlock(pad_temp_0) + { + i0, i1, i2, i3 = axis.bind(i, j, k, a) + { + pad_temp_0[i, j, k, a] = select(((a < 17) and ((a >= 1) and ((k < 17) and (k >= 1)))), input[i, j, (-1 + k), (-1 + a)], -3.40282347e+38f) + } + } } } } } - } - } -} -} // end Expr 0 -Expr 1 { -{ - ScheduleBlock(root_0) - { - { - thread_bind[blockIdx.x] for (i_j_k_a_fused, 0, 16) { - thread_bind[threadIdx.x] for (i_0_j_0_k_0_a_0_fused, 0, 4) + thread_bind[blockIdx.x] for (i_j_k_a_fused, 0, 16) { - serial for (i_1, 0, 1) + thread_bind[threadIdx.x] for (i_0_j_0_k_0_a_0_fused, 0, 4) { - serial for (j_1, 0, 4) + serial for (i_1, 0, 1) { - serial for (k_1, 0, 1) + serial for (j_1, 0, 4) { - serial for (a_1, 0, 4) + serial for (k_1, 0, 1) { - ScheduleBlock(var_0__reduce_init) + serial for (a_1, 0, 4) { - i0_0, i1_0, i2_0, i3_0 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 
4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)) + ScheduleBlock(var_0__reduce_init) { - var_0__reduce_init[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = -3.40282347e+38f + i0_0, i1_0, i2_0, i3_0 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)) + { + var_0__reduce_init[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = -3.40282347e+38f + } } } } } } - } - { - serial for (kernel_idx, 0, 3) { - serial for (kernel_idx_0, 0, 3) + serial for (kernel_idx, 0, 3) { - serial for (ax0_ax1_ax2_ax3_fused, 0, 28) + serial for (kernel_idx_0, 0, 3) { - ScheduleBlock(pad_temp_0_shared_temp_buffer) + serial for (ax0_ax1_ax2_ax3_fused, 0, 28) { - v0, v1, v2, v3 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + ((ax0_ax1_ax2_ax3_fused / 7) / 4))), (((ax0_ax1_ax2_ax3_fused / 7) % 4) + (4 * (((i_j_k_a_fused / 2) / 2) % 2))), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + kernel_idx)), ((ax0_ax1_ax2_ax3_fused % 7) + ((8 * (i_j_k_a_fused % 2)) + kernel_idx_0))) - attrs(compute_at_extra_var:ax0,ax1,ax2,ax3, cooperative_process:0) + ScheduleBlock(pad_temp_0_shared_temp_buffer) { - pad_temp_0_shared_temp_buffer[v0, v1, v2, v3] = pad_temp_0[v0, v1, v2, v3] + v0, v1, v2, v3 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + ((ax0_ax1_ax2_ax3_fused / 7) / 4))), (((ax0_ax1_ax2_ax3_fused / 7) % 4) + (4 * (((i_j_k_a_fused / 2) / 2) % 2))), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + kernel_idx)), ((ax0_ax1_ax2_ax3_fused % 7) + ((8 * (i_j_k_a_fused % 2)) + kernel_idx_0))) + attrs(compute_at_extra_var:ax0,ax1,ax2,ax3, cooperative_process:0) + { + pad_temp_0_shared_temp_buffer[v0, v1, v2, v3] = pad_temp_0[v0, v1, v2, v3] + } } } - } - serial for (i_1, 0, 1) - { - serial for (j_1, 0, 4) + serial for (i_1, 0, 1) { - serial for (k_1, 0, 1) + serial for (j_1, 0, 4) { - serial for (a_1, 0, 4) + serial for (k_1, 0, 1) { - ScheduleBlock(var_0_local_temp_buffer) + serial for (a_1, 0, 4) { - i0_1, i1_1, i2_1, i3_1, i4, i5 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1), kernel_idx, kernel_idx_0) - read_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)], _pad_temp_0[i(undefined:undefined), j(undefined:undefined)]) - write_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)]) + ScheduleBlock(var_0_local_temp_buffer) { - var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] 
= cinn_max(var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)], pad_temp_0_shared_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + ((2 * k_1) + kernel_idx))), ((8 * (i_j_k_a_fused % 2)) + ((2 * a_1) + kernel_idx_0))]) + i0_1, i1_1, i2_1, i3_1, i4, i5 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1), kernel_idx, kernel_idx_0) + read_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)], _pad_temp_0[i(undefined:undefined), j(undefined:undefined)]) + write_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)]) + { + var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = cinn_max(var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)], pad_temp_0_shared_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + ((2 * k_1) + kernel_idx))), ((8 * (i_j_k_a_fused % 2)) + ((2 * a_1) + kernel_idx_0))]) + } } } } @@ -526,21 +523,21 @@ Expr 1 { } } } - } - serial for (ax0_0, 0, 1) - { - serial for (ax1_0, 0, 4) + serial for (ax0_0, 0, 1) { - serial for (ax2_0, 0, 1) + serial for (ax1_0, 0, 4) { - serial for (ax3_0, 0, 4) + serial for (ax2_0, 0, 1) { - ScheduleBlock(var_0) + serial for (ax3_0, 0, 4) { - v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((4 * ((i_j_k_a_fused / 2) % 2)) + (i_0_j_0_k_0_a_0_fused % 4)) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0)) - attrs(reverse_compute_at_extra_var:ax0_0,ax1_0,ax2_0,ax3_0) + ScheduleBlock(var_0) { - var_0[v0, v1, v2, v3] = var_0_local_temp_buffer[v0, v1, v2, v3] + v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((4 * ((i_j_k_a_fused / 2) % 2)) + (i_0_j_0_k_0_a_0_fused % 4)) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0)) + attrs(reverse_compute_at_extra_var:ax0_0,ax1_0,ax2_0,ax3_0) + { + var_0[v0, v1, v2, v3] = var_0_local_temp_buffer[v0, v1, v2, v3] + } } } } @@ -553,7 +550,7 @@ Expr 1 { } } } -} // end Expr 1 +} // end Expr 0 )ROC"; ASSERT_EQ(ir, expected_ir); @@ -569,8 +566,8 @@ Expr 1 { pool2d_program, fixed_rand_seed, /* apply_manual_schedule*/ true))), default_input_names, default_output_names, - {input_shape}, - {output_shape}, + input_shapes, + output_shapes, target_); } diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc 
b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc index e8ed904066d17cf5a57dd5d6a4d105d5ad3f931d..19a9534dfd69459a2813f988a60a5370dd573827 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc @@ -63,12 +63,10 @@ ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule( absl::flat_hash_map>("infershape"); hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target_); - if (apply_manual_schedule) { - lowered_funcs_ = op_lowerer.Lower(graph->fusion_groups.front()); - } else { - lowered_funcs_ = - op_lowerer.LowerWithoutSchedule(graph->fusion_groups.front()); - } + lowered_funcs_ = + op_lowerer.Lower(graph->fusion_groups.front(), + /*apply_op_schedule = */ apply_manual_schedule, + /*apply_group_schedule = */ apply_manual_schedule); CHECK(!lowered_funcs_.empty()) << "lowered_funcs_ is empty"; std::vector bodys; diff --git a/paddle/cinn/auto_schedule/task/tune_task.cc b/paddle/cinn/auto_schedule/task/tune_task.cc index a6c11a4e4d58b267e00472a29758e7e9abdcb025..091a45b1c304abb373cd90f8b121f2ccc7199e00 100644 --- a/paddle/cinn/auto_schedule/task/tune_task.cc +++ b/paddle/cinn/auto_schedule/task/tune_task.cc @@ -39,7 +39,8 @@ void TuneTask::Initialize( op_lowerer = lower_handler; // Set lowered_funcs and analyze output names. - this->lowered_funcs = op_lowerer->LowerWithoutSchedule(subgraph); + this->lowered_funcs = op_lowerer->Lower( + subgraph, /*apply_op_schedule = */ false, /*apply_group_schedule=*/false); this->output_names = GetOutputNamesFromLoweredFunc(this->lowered_funcs); this->serialized_key = SerializeToString(shape_dict, dtype_dict); } diff --git a/paddle/cinn/auto_schedule/tests/performance_comparison_test.cc b/paddle/cinn/auto_schedule/tests/performance_comparison_test.cc index 79b4dc95d180c9846b63cd75d26e3e38ab6a8055..628d8909f270a38d468ea77f19f33a78ec74c0df 100644 --- a/paddle/cinn/auto_schedule/tests/performance_comparison_test.cc +++ b/paddle/cinn/auto_schedule/tests/performance_comparison_test.cc @@ -157,7 +157,9 @@ class PerformanceTester : public ::testing::Test { for (auto group : graph->fusion_groups) { compile_options.lowered_funcs.push_back( - op_lowerer->LowerWithoutSchedule(group)); + op_lowerer->Lower(group, + /*apply_op_schedule = */ false, + /*apply_group_schedule=*/false)); } VLOG(3) << "===========================No Schedule LoweredFunc " diff --git a/paddle/cinn/hlir/framework/op_lowering.cc b/paddle/cinn/hlir/framework/op_lowering.cc index bf6099cc9a6bf4caacd2c6510a8376bb268b6328..d26e40891a7b31183d30195aefb39c57231d802e 100644 --- a/paddle/cinn/hlir/framework/op_lowering.cc +++ b/paddle/cinn/hlir/framework/op_lowering.cc @@ -45,7 +45,9 @@ OpLowerer::OpLowerer( const Target& target) : type_dict_(type_dict), shape_dict_(shape_dict), target_(target) {} -std::vector OpLowerer::Lower(GroupPtr& group) { // NOLINT +std::vector OpLowerer::Lower(const GroupPtr& group, + bool apply_op_schedule, + bool apply_group_schedule) { VLOG(3) << "Lowering Group : " << group->group_id << " , Op Pattern : " << group->op_pattern_kind; group->input_names.clear(); @@ -55,13 +57,22 @@ std::vector OpLowerer::Lower(GroupPtr& group) { // NOLINT case framework::kElementWise: case framework::kBroadcast: case framework::kInjective: - return IRLowerOp(&OpLowerer::IRElementwiseCompute, group); + return LowerGroup(group, + apply_op_schedule, + apply_group_schedule, + &OpLowerer::ElementwiseScheduleDetermineFunction); case framework::kReduction: - return 
IRLowerOp(&OpLowerer::IRReduceCompute, group); + return LowerGroup(group, + apply_op_schedule, + apply_group_schedule, + &OpLowerer::ReduceScheduleDetermineFunction); case framework::kOutFusible: LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!"; case framework::kNonFusible: - return IRLowerNonFusibleOp(group, /*apply_impl_schedule = */ true); + return LowerGroup(group, + apply_op_schedule, + apply_group_schedule, + &OpLowerer::NonFusibleScheduleDetermineFunction); default: LOG(FATAL) << "Group Pattern Kind Is Unknown!"; } @@ -70,532 +81,329 @@ std::vector OpLowerer::Lower(GroupPtr& group) { // NOLINT } } -std::vector OpLowerer::LowerWithoutSchedule(GroupPtr& group) { - VLOG(3) << "Lowering Group : " << group->group_id - << " , Op Pattern : " << group->op_pattern_kind; - if (FLAGS_cinn_ir_schedule) { - switch (group->op_pattern_kind) { - case framework::kElementWise: - case framework::kBroadcast: - case framework::kInjective: - return IRLowerOpWithoutSchedule(&OpLowerer::IRElementwiseCompute, - group); - case framework::kReduction: - return IRLowerOpWithoutSchedule(&OpLowerer::IRReduceCompute, group); - case framework::kOutFusible: - LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!"; - case framework::kNonFusible: - return IRLowerNonFusibleOp(group, /*apply_impl_schedule = */ false); - default: - LOG(FATAL) << "Group Pattern Kind kNonFusible Is Not Implemented!"; - } - } else { - LOG(FATAL) << "Previous IR Schedule Is Not Implemented!"; - } +bool OpLowerer::ElementwiseScheduleDetermineFunction(Node* node) { + return true; } -std::vector OpLowerer::IRLowerOp(IRComputeFunction compute, - GroupPtr& group) { - poly::StageMap stages; - std::vector arg_tensors; - std::unordered_map tensor_map; - // do compute. +bool OpLowerer::ReduceScheduleDetermineFunction(Node* node) { + auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + return op_pattern_dict[node->op()] == framework::kReduction; +} + +bool OpLowerer::NonFusibleScheduleDetermineFunction(Node* node) { return true; } + +std::vector OpLowerer::LowerGroup( + const GroupPtr& group, + bool apply_op_schedule, + bool apply_group_schedule, + ScheduleDetermineFunction schedule_determine_func) { + // 1.Do compute, lower and schedule for each op. VLOG(3) << "group->fused_sub_groups.size() is : " << group->fused_sub_groups.size(); - std::vector ast_exprs; - if (group->fused_sub_groups.size() == 0) { - ast_exprs = (this->*compute)(stages, - arg_tensors, - tensor_map, - group, - group, - /*apply_impl_schedule = */ true); - } else { - for (auto& sub_group : group->fused_sub_groups) { - auto exprs = (this->*compute)(stages, - arg_tensors, - tensor_map, - group, - sub_group, - /*apply_impl_schedule = */ true); - ast_exprs.insert(ast_exprs.end(), exprs.begin(), exprs.end()); - } + std::vector nodes = group->CollectNodes(); + if (nodes.size() == 1 && nodes[0]->op()->name == "custom_call") { + return LowerCustomCall(group); } - ir::ModuleExpr mod_expr(ast_exprs); + std::vector group_func_arg_tensors; + std::unordered_map tensor_map; + bool do_op_schedule = apply_group_schedule || apply_op_schedule; + std::vector func_bodies = LowerOps(nodes, + do_op_schedule, + schedule_determine_func, + &group_func_arg_tensors, + &tensor_map); + + // 2.Do group schedule. + ir::ModuleExpr mod_expr(func_bodies); ir::IRSchedule ir_sch(mod_expr); ir_sch.MergeExprs(); - - Node* first = nullptr; - Node* second = nullptr; - - VLOG(3) << "Before IRLowerOp schedule, ir is: \n" - << ir_sch.GetModule().GetExprs().at(0); - // do schedule. 
- IRSchedule(ir_sch, group, tensor_map); - VLOG(3) << "After IRLowerOp schedule, ir is: \n" - << ir_sch.GetModule().GetExprs().at(0); - // function args - group->input_names.clear(); - std::vector func_args; - for (auto& args : arg_tensors) { - // input node data name. - group->input_names.push_back(args->name); - // input args - func_args.emplace_back(args->buffer, ir::Argument::IO::kInput); + VLOG(3) << "After lower, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); + if (apply_group_schedule) { + DoGroupSchedule(ir_sch, group, tensor_map); + VLOG(3) << "After group schedule, ir is: \n" + << ir_sch.GetModule().GetExprs().at(0); } - group->output_names.clear(); - for (auto& node : group->output_nodes) { - // output node data name. - for (auto node_data : GetAllNodeData(node)) { - group->output_names.push_back(node_data->id()); - } - // collect all output tensor. - std::string post = ""; - std::string prefix = GetNodeData(node)->id(); - for (int idx = 0; idx < 1; ++idx) { - CHECK(tensor_map.count(prefix)) << "Can't find output tensor " << prefix; - if (!tensor_map.count(prefix + post)) { - break; - } - auto tensor = tensor_map[prefix + post]; - arg_tensors.push_back(tensor); - // output args - func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput); - // update post - post = "_" + std::to_string(idx); - } - } - auto func_body = ir_sch.GetModule().GetExprs().at(0); -#ifdef CINN_WITH_CUDA - optim::OptimizeExprGPU(&(func_body)); -#endif - - auto temp_buffers = lang::GetTempBuffers(arg_tensors, stages, func_body); - auto func = ir::_LoweredFunc_::Make(group->GetFuncName(), - func_args, - ir_sch.GetModule().GetExprs().at(0), - temp_buffers); - func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref(); - return {func}; + // 3.Do post-processing, + // including preparing function args and temporary variables, + // applying low-level optimization passes, etc. + return PostProcess( + group, tensor_map, do_op_schedule, &ir_sch, &group_func_arg_tensors); } -std::vector OpLowerer::IRLowerOpWithoutSchedule( - IRComputeFunction compute, GroupPtr& group) { - poly::StageMap stages; - std::vector arg_tensors; +std::vector OpLowerer::LowerCustomCall(const GroupPtr& group) { + std::vector nodes = group->CollectNodes(); + CHECK_EQ(nodes.size(), 1); + Node* node = nodes[0]; + std::vector op_func_arg_tensors; std::unordered_map tensor_map; - // do compute. - VLOG(3) << "group->fused_sub_groups.size() is : " - << group->fused_sub_groups.size(); - std::vector ast_exprs; - if (group->fused_sub_groups.size() == 0) { - ast_exprs = (this->*compute)(stages, - arg_tensors, - tensor_map, - group, - group, - /*apply_impl_schedule = */ false); - } else { - for (auto& sub_group : group->fused_sub_groups) { - auto exprs = (this->*compute)(stages, - arg_tensors, - tensor_map, - group, - sub_group, - /*apply_impl_schedule = */ false); - ast_exprs.insert(ast_exprs.end(), exprs.begin(), exprs.end()); + for (auto& node_data : GetInputNodeData(node)) { + CHECK(node_data); + ir::Tensor tensor; + if (!tensor_map.count(node_data->id())) { + tensor = GetTensor(node_data, this->type_dict_, this->shape_dict_); + // record tensor. + tensor_map[node_data->id()] = tensor; + // input name. 
+ group->input_names.push_back(node_data->id()); + } else { + tensor = tensor_map[node_data->id()]; } + op_func_arg_tensors.push_back(tensor); } - ir::ModuleExpr mod_expr(ast_exprs); - ir::IRSchedule ir_sch(mod_expr); - ir_sch.MergeExprs(); - VLOG(3) << "After IRLowerOp compute, ir is: \n" - << ir_sch.GetModule().GetExprs().at(0); - // function args + std::vector out_types; + std::vector> out_shapes; + auto node_datas = GetAllNodeData(node); + for (auto node_data : node_datas) { + group->output_names.push_back(node_data->id()); + out_types.push_back(this->type_dict_.at(node_data->id())); + out_shapes.push_back(this->shape_dict_.at(node_data->id())); + } + auto& cinn_strategy = Operator::GetAttrs("CINNStrategy"); + auto impl = OpStrategy::SelectImpl(cinn_strategy[node->op()]( + node->attrs, op_func_arg_tensors, out_types, out_shapes, target_)); + std::string external_api; + if (node->attrs.attr_store.count("custom_call")) { + external_api = + absl::get(node->attrs.attr_store.at("custom_call")); + } else { + external_api = ExternalApiRegistry::Global()->GetExternalApi(node, target_); + } + std::vector compute_args = { + common::CINNValue(group->GetFuncName()), common::CINNValue(external_api)}; + common::CINNValuePack pack = + impl->fcompute(common::CINNValuePack{compute_args}); + CHECK_EQ(pack.size(), 1UL); + // reset input names as extern api input args can't be remove duplicate. + group->input_names.clear(); + for (auto& inode : node->inlinks_in_order()) { + group->input_names.push_back(inode->source()->as()->id()); + } + return {pack[0].operator ir::Expr().as_lowered_func_ref()}; +} + +std::vector OpLowerer::PostProcess( + const GroupPtr& group, + const std::unordered_map& tensor_map, + bool done_op_schedule, + ir::IRSchedule* ir_sch, + std::vector* group_func_arg_tensors) { + // 1.Prepare function args group->input_names.clear(); - std::vector func_args; - for (auto& args : arg_tensors) { + std::vector group_func_args; + std::unordered_set arg_name_set; + for (auto& arg_tensor : *group_func_arg_tensors) { // input node data name. - group->input_names.push_back(args->name); + group->input_names.push_back(arg_tensor->name); // input args - func_args.emplace_back(args->buffer, ir::Argument::IO::kInput); + group_func_args.emplace_back(arg_tensor->buffer, ir::Argument::IO::kInput); + arg_name_set.insert(arg_tensor->buffer->name); } group->output_names.clear(); for (auto& node : group->output_nodes) { - // output node data name. - for (auto node_data : GetAllNodeData(node)) { - group->output_names.push_back(node_data->id()); - } // collect all output tensor. 
- std::string post = ""; - std::string prefix = GetNodeData(node)->id(); - for (int idx = 0; idx < 1; ++idx) { - CHECK(tensor_map.count(prefix)) << "Can't find output tensor " << prefix; - if (!tensor_map.count(prefix + post)) { - break; + for (auto node_data : GetAllNodeData(node)) { + std::string output_node_data_name = node_data->id(); + group->output_names.push_back(output_node_data_name); + // CHECK(tensor_map.count(output_node_data_name)) << "Can't find output + // tensor " << output_node_data_name; + if (tensor_map.count(output_node_data_name) == 0) { + continue; + } + auto tensor = tensor_map.at(output_node_data_name); + if (arg_name_set.count(tensor->buffer->name) != 0) { + continue; } - auto tensor = tensor_map[prefix + post]; - arg_tensors.push_back(tensor); + // output arg tensors + group_func_arg_tensors->push_back(tensor); // output args - func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput); - // update post - post = "_" + std::to_string(idx); + group_func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput); + arg_name_set.insert(tensor->buffer->name); } } - std::unordered_set args_map; - for (auto arg : func_args) { - args_map.insert(arg.name()); - } + if (!done_op_schedule) { + std::unordered_set args_set; + for (auto arg : group_func_args) { + args_set.insert(arg.name()); + } - for (auto& tensor : tensor_map) { - if (args_map.count("_" + tensor.first)) { - continue; + for (auto& tensor_pair : tensor_map) { + if (args_set.count("_" + tensor_pair.second->name)) { + continue; + } + group_func_arg_tensors->push_back(tensor_pair.second); + // use the underlying tensor name to be consistent with the argument name + // in the lowered function + group->output_names.push_back(tensor_pair.second->name); + group_func_args.emplace_back(tensor_pair.second->buffer, + ir::Argument::IO::kOutput); } - arg_tensors.push_back(tensor.second); - // use the underlying tensor name to be consistent with the argument name in - // the lowered function - group->output_names.push_back(tensor.second->name); - func_args.emplace_back(tensor.second->buffer, ir::Argument::IO::kOutput); } - auto func_body = ir_sch.GetModule().GetExprs().at(0); + auto func_body = ir_sch->GetModule().GetExprs().at(0); #ifdef CINN_WITH_CUDA optim::OptimizeExprGPU(&(func_body)); #endif - auto temp_buffers = lang::GetTempBuffers(arg_tensors, stages, func_body); + // 2.Prepare temp buffers + poly::StageMap stages; + auto temp_buffers = + lang::GetTempBuffers(*group_func_arg_tensors, stages, func_body); + // 3.Building LoweredFunc auto func = ir::_LoweredFunc_::Make(group->GetFuncName(), - func_args, - ir_sch.GetModule().GetExprs().at(0), + group_func_args, + ir_sch->GetModule().GetExprs().at(0), temp_buffers); - func->PrepareBufferCastExprs(); + if (!done_op_schedule) { + func->PrepareBufferCastExprs(); + } + // 4.Apply low level pass func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref(); - return {func}; } -std::vector OpLowerer::IRElementwiseCompute( - poly::StageMap& stages, - std::vector& func_tensors, - std::unordered_map& tensor_map, - const GroupPtr& group, - const GroupPtr& sub_group, - bool apply_impl_schedule) { - VLOG(2) << "ElementwiseCompute Group : " << sub_group->group_id; +std::vector OpLowerer::LowerOps( + const std::vector& nodes, + bool apply_op_schedule, + ScheduleDetermineFunction schedule_determine_func, + std::vector* group_func_arg_tensors, + std::unordered_map* tensor_map) { auto& strategy = Operator::GetAttrs("CINNStrategy"); - - std::vector ast_exprs; - for 
(auto& node : sub_group->nodes) { - VLOG(4) << "Lower op: " << node->op()->name; - auto node_data = GetNodeData(node); - CHECK_EQ(GetAllNodeData(node).size(), 1U); - std::vector cinn_inputs; - std::vector tensor_inputs = std::move(CollectInputTensor( - node, func_tensors, tensor_map, this->type_dict_, this->shape_dict_)); - for (auto& tensor : tensor_inputs) { - cinn_inputs.push_back(common::CINNValue(ir::Expr(tensor))); - } - // set tensor name = node data name - cinn_inputs.push_back(common::CINNValue(node_data->id())); - + std::vector func_bodies; + for (Node* node : nodes) { + // 1.Select Op impl std::vector out_types; std::vector> out_shapes; - out_types.push_back(this->type_dict_.at(node_data->id())); - out_shapes.push_back(this->shape_dict_.at(node_data->id())); - auto impl = OpStrategy::SelectImpl(strategy[node->op()]( - node->attrs, tensor_inputs, out_types, out_shapes, this->target_)); - // do compute - common::CINNValuePack pack = - impl->fcompute(common::CINNValuePack{cinn_inputs}); - CHECK_EQ(pack.size(), 2U); - - Expr expr = pack[0]; - poly::StageMap node_stages = pack.back(); - tensor_inputs.push_back(expr.as_tensor_ref()); - tensor_map[node_data->id()] = expr.as_tensor_ref(); - - auto func = lang::LowerVec("fn_" + node->id(), - node_stages, - tensor_inputs, - {}, - {}, - nullptr, - this->target_, - true); - CHECK_EQ(func.size(), 1); - - if (apply_impl_schedule) { - std::vector schedule_inputs; - // collect tensor - for (int idx = 0; idx < pack.size() - 1; ++idx) { - CHECK(pack[idx].is_tensor()); - schedule_inputs.push_back(common::CINNValue(pack[idx])); - } - for (auto& f : func) { - schedule_inputs.push_back(common::CINNValue(f->body)); - } - // do ast tree schedule - common::CINNValuePack expr_pack = - impl->fschedule(common::CINNValuePack{schedule_inputs}); - - CHECK_EQ(expr_pack.size(), 1); - Expr ast_expr = expr_pack[0]; - ast_exprs.push_back(ast_expr); - } else { - ast_exprs.push_back(func[0]->body); + std::vector node_datas = GetAllNodeData(node); + for (const auto& node_data : node_datas) { + out_types.push_back(this->type_dict_.at(node_data->id())); + out_shapes.push_back(this->shape_dict_.at(node_data->id())); } - } - - return ast_exprs; -} - -std::vector OpLowerer::IRReduceCompute( - poly::StageMap& stages, - std::vector& func_args, - std::unordered_map& tensor_map, - const GroupPtr& group, - const GroupPtr& sub_group, - bool apply_impl_schedule) { - VLOG(2) << "ReduceCompute Group : " << sub_group->group_id; - auto& cinn_strategy = Operator::GetAttrs("CINNStrategy"); - auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); - - std::vector ast_exprs; - for (auto& node : sub_group->nodes) { - auto node_data = GetNodeData(node); - VLOG(3) << "In ReduceCompute, process node: " << node->id() - << " with op type: " << node->op()->name; - - std::vector cinn_inputs; - std::vector tensor_inputs = std::move(CollectInputTensor( - node, func_args, tensor_map, this->type_dict_, this->shape_dict_)); - for (auto& tensor : tensor_inputs) { - cinn_inputs.push_back(common::CINNValue(ir::Expr(tensor))); - } - cinn_inputs.push_back(common::CINNValue(node_data->id())); - - std::vector out_types; - std::vector> out_shapes; - - out_types.push_back(this->type_dict_.at(node_data->id())); - out_shapes.push_back(this->shape_dict_.at(node_data->id())); - - auto impl = OpStrategy::SelectImpl(cinn_strategy[node->op()]( - node->attrs, tensor_inputs, out_types, out_shapes, target_)); - // do compute - common::CINNValuePack pack = - impl->fcompute(common::CINNValuePack{cinn_inputs}); - - 
CHECK_GE(pack.size(), 2UL); - CHECK_LE(pack.size(), 5UL); - poly::StageMap tmp_stages = pack.back(); - - std::string post = ""; - for (int idx = 0; idx < pack.size() - 1; ++idx) { - Expr expr = pack[idx]; - tensor_map[node_data->id() + post] = expr.as_tensor_ref(); - // As op may has more than 1 output tensor, using id + "_0"/"_1" as key. - post = "_" + std::to_string(idx); - - // Insert outout tensors - if (!expr.as_tensor_ref()->buffer.defined() || - this->target_ != common::DefaultNVGPUTarget()) { - tensor_inputs.push_back(expr.as_tensor_ref()); - } - } - auto func = lang::LowerVec("fn_" + node->id(), - tmp_stages, - tensor_inputs, - {}, - {}, - nullptr, - this->target_, - true); - - // node is kReduction - if (op_pattern_dict[node->op()] == framework::kReduction && - apply_impl_schedule) { - std::vector schedule_inputs; - // collect tensor - for (int idx = 0; idx < pack.size() - 1; ++idx) { - CHECK(pack[idx].is_tensor()); - schedule_inputs.push_back(common::CINNValue(pack[idx])); - } - for (auto& f : func) { - schedule_inputs.push_back(common::CINNValue(f->body)); - } - // do ast tree schedule - common::CINNValuePack expr_pack = - impl->fschedule(common::CINNValuePack{schedule_inputs}); - // ast tree after schedule. - Expr ast_expr = expr_pack[0]; - ast_exprs.push_back(ast_expr); - } else if (group->master_nodes.count(node)) { - // as master node should copy transform from reducer, left it to reduce - // schedule. - ast_exprs.push_back(func[0]->body); + std::vector op_func_arg_tensors = + std::move(CollectInputTensor(node, + this->type_dict_, + this->shape_dict_, + group_func_arg_tensors, + tensor_map)); + auto op_impl = + OpStrategy::SelectImpl(strategy[node->op()](node->attrs, + op_func_arg_tensors, + out_types, + out_shapes, + this->target_)); + + // 2.Perform the lower process of Op + std::vector funcs = + DoOpLower(op_impl, node, tensor_map, &op_func_arg_tensors); + + if (apply_op_schedule && (this->*schedule_determine_func)(node)) { + // 3.Perform the schedule of Op + func_bodies.push_back(DoOpSchedule(op_impl, op_func_arg_tensors, funcs)); } else { - ast_exprs.push_back(func[0]->body); + for (const ir::LoweredFunc& func : funcs) { + func_bodies.push_back(func->body); + } } } - return ast_exprs; + return func_bodies; } -std::vector OpLowerer::IRLowerNonFusibleOp( - GroupPtr& group, bool apply_impl_schedule) { - VLOG(3) << "LowerNonFusibleOp Group : " << group->group_id; - // get input tensor and output tensor - CHECK(group->nodes.size() || group->fused_sub_groups.size()); - auto& cinn_strategy = Operator::GetAttrs("CINNStrategy"); - auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); - - auto node = group->fused_sub_groups.size() - ? group->fused_sub_groups[0]->nodes.front() - : group->nodes.front(); - VLOG(3) << "GetOpFunc of op " << node->id(); - std::vector inputs; +std::vector OpLowerer::DoOpLower( + std::shared_ptr op_impl, + Node* node, + std::unordered_map* tensor_map, + std::vector* op_func_arg_tensors) { + VLOG(4) << "Do lower with Compute, op: " << node->op()->name; std::vector cinn_inputs; - - std::vector args; - std::unordered_map tensor_map; - for (auto& node_data : GetInputNodeData(node)) { - CHECK(node_data); - ir::Tensor tensor; - if (!tensor_map.count(node_data->id())) { - tensor = GetTensor(node_data, this->type_dict_, this->shape_dict_); - // record tensor. - tensor_map[node_data->id()] = tensor; - // input name. - group->input_names.push_back(node_data->id()); - // input type. 
- args.emplace_back(tensor->buffer, ir::Argument::IO::kInput); - } else { - tensor = tensor_map[node_data->id()]; - } - inputs.push_back(tensor); - cinn_inputs.push_back(common::CINNValue(tensor)); + for (const ir::Tensor& tensor : *op_func_arg_tensors) { + cinn_inputs.push_back(common::CINNValue(ir::Expr(tensor))); } - - std::vector out_types; - std::vector> out_shapes; - auto node_datas = GetAllNodeData(node); - for (auto node_data : node_datas) { - VLOG(3) << "cinn_inputs.push_back " << node_data->id(); - group->output_names.push_back(node_data->id()); - out_types.push_back(this->type_dict_.at(node_data->id())); - out_shapes.push_back(this->shape_dict_.at(node_data->id())); + // set tensor name = node data name + std::vector node_datas = GetAllNodeData(node); + for (const NodeData* node_data : node_datas) { cinn_inputs.push_back(common::CINNValue(node_data->id())); } - auto impl = OpStrategy::SelectImpl(cinn_strategy[node->op()]( - node->attrs, inputs, out_types, out_shapes, target_)); - // if node op is custom_call, apply custom_call compute. - if (node->op()->name == "custom_call") { - std::string external_api; - if (node->attrs.attr_store.count("custom_call")) { - external_api = - absl::get(node->attrs.attr_store.at("custom_call")); + // 1.Do compute + common::CINNValuePack pack = + op_impl->fcompute(common::CINNValuePack{cinn_inputs}); + + poly::StageMap tmp_stages = pack.back(); + std::string post = ""; + for (int idx = 0; idx < pack.size() - 1; ++idx) { + Expr expr = pack[idx]; + // Insert the output tensor defined by Compute into the tensor_map + if (pack.size() - 1 > node_datas.size()) { + // Some nodes may output multiple temp tensors in their Compute + // definition, but only one output node_data in the graph, and we use id + + // "_0"/"_1" as key. + (*tensor_map)[node_datas[0]->id() + post] = expr.as_tensor_ref(); + post = "_" + std::to_string(idx); } else { - external_api = - ExternalApiRegistry::Global()->GetExternalApi(node, target_); + // If the number of output tensors defined by Compute is less equal than + // the output node_data on the graph, then there is a one-to-one + // correspondence, and the redundant output node_data contact empty. + (*tensor_map)[node_datas[idx]->id()] = expr.as_tensor_ref(); } - std::vector compute_args = { - common::CINNValue(group->GetFuncName()), - common::CINNValue(external_api)}; - common::CINNValuePack pack = - impl->fcompute(common::CINNValuePack{compute_args}); - CHECK_EQ(pack.size(), 1UL); - // reset input names as extern api input args can't be remove duplicate. - group->input_names.clear(); - for (auto& inode : node->inlinks_in_order()) { - group->input_names.push_back(inode->source()->as()->id()); - } - return {pack[0].operator ir::Expr().as_lowered_func_ref()}; - } - common::CINNValuePack pack = - impl->fcompute(common::CINNValuePack{cinn_inputs}); - for (int i = 0; i < pack->size() - 1; i++) { - ir::Expr temp = pack[i]; - // checkout whether the tensor is with buffer. 
- if (!temp.as_tensor_ref()->buffer.defined() || + // Insert output tensors into function arg + if (!expr.as_tensor_ref()->buffer.defined() || this->target_ != common::DefaultNVGPUTarget()) { - inputs.push_back(temp.as_tensor_ref()); - temp.as_tensor_ref()->WithBuffer(); - args.emplace_back(temp.as_tensor_ref()->buffer, - ir::Argument::IO::kOutput); + op_func_arg_tensors->push_back(expr.as_tensor_ref()); + expr.as_tensor_ref()->WithBuffer(); } } - poly::StageMap stages = pack.back(); - auto func = lang::LowerVec(group->GetFuncName(), - stages, - inputs, - {}, - {}, - nullptr, - this->target_, - true); - - if (apply_impl_schedule) { - std::vector schedule_inputs; - // collect tensor - for (int idx = 0; idx < pack.size() - 1; ++idx) { - CHECK(pack[idx].is_tensor()); - schedule_inputs.push_back(common::CINNValue(pack[idx])); - } - for (auto& f : func) { - schedule_inputs.push_back(common::CINNValue(f->body)); - } - // do ast tree schedule - common::CINNValuePack expr_pack = - impl->fschedule(common::CINNValuePack{schedule_inputs}); - - ir::Expr func_body = expr_pack[0]; - std::vector input_output_nodes(group->input_names); - input_output_nodes.insert(input_output_nodes.end(), - group->output_names.begin(), - group->output_names.end()); - VLOG(6) << "func.size() = " << func.size() - << ", expr_pack.size() = " << expr_pack.size(); - VLOG(6) << "args.size() = " << args.size() - << ", input_output_nodes.size() = " << input_output_nodes.size(); - if (args.size() > input_output_nodes.size()) { - args = lang::GetArgs(func_body, input_output_nodes); - } - std::vector res; - for (int i = 0; i < expr_pack.size(); i++) { - ir::Expr func_body = expr_pack[0]; -#ifdef CINN_WITH_CUDA - optim::OptimizeExprGPU(&(func_body)); -#endif - auto temp_buffers = lang::GetTempBuffers(inputs, stages, func_body); - auto function = ir::_LoweredFunc_::Make( - group->GetFuncName(), args, func_body, temp_buffers); - res.push_back(function); - } - for (auto& i : res) { - i = optim::Optimize(Expr(i), target_, false).as_lowered_func_ref(); - } - return res; - } else { - for (auto& f : func) { -#ifdef CINN_WITH_CUDA - optim::OptimizeExprGPU(&(f->body)); -#endif - f = optim::Optimize(Expr(f), target_, false).as_lowered_func_ref(); - } - return func; + // 2.Do lower + std::vector funcs = lang::LowerVec("fn_" + node->id(), + tmp_stages, + *op_func_arg_tensors, + {}, + {}, + nullptr, + this->target_, + true); + VLOG(4) << "Lower op: " << node->op()->name << ", get " << funcs.size() + << " LoweredFunc:\n"; + + op_func_arg_tensors->clear(); + for (int idx = 0; idx < pack.size() - 1; ++idx) { + CHECK(pack[idx].is_tensor()); + op_func_arg_tensors->push_back( + pack[idx].operator ir::Expr().as_tensor_ref()); } + + return funcs; +} + +ir::Expr OpLowerer::DoOpSchedule( + std::shared_ptr op_impl, + const std::vector& op_func_arg_tensors, + const std::vector& lowered_funcs) { + VLOG(4) << "Do op schedule"; + std::vector schedule_inputs; + // 1.Collect tensors + for (const ir::Tensor& op_func_arg_tensor : op_func_arg_tensors) { + schedule_inputs.push_back(common::CINNValue(op_func_arg_tensor)); + } + // 2.Collect bodies to be scheduled + for (const ir::LoweredFunc& func : lowered_funcs) { + schedule_inputs.push_back(common::CINNValue(func->body)); + } + // 3.Do schedule on AST + common::CINNValuePack expr_pack = + op_impl->fschedule(common::CINNValuePack{schedule_inputs}); + VLOG(4) << "After op schedule: " << expr_pack[0].operator ir::Expr(); + + return expr_pack[0].operator ir::Expr(); } // group schedule -void OpLowerer::IRSchedule( 
+ir::Expr OpLowerer::DoGroupSchedule( ir::IRSchedule& ir_sch, const GroupPtr& group, const std::unordered_map& tensor_map) { @@ -698,6 +506,7 @@ void OpLowerer::IRSchedule( << ", ir is:\n" << ir_sch.GetModule().GetExprs().at(0); // if node is horizontal with reduce or node is reduce, loop assign + // // master. auto loops = ir_sch.GetLoops(GetNodeData(node)->id()); if (op_pattern_dict[node->op()] == framework::kElementWise) { @@ -788,6 +597,7 @@ void OpLowerer::IRSchedule( ir_sch, group, nodes_inline, nodes_set, this->shape_dict_, tensor_map); VLOG(4) << "After IRSchedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); + return ir_sch.GetModule().GetExprs().at(0); } } // namespace framework diff --git a/paddle/cinn/hlir/framework/op_lowering.h b/paddle/cinn/hlir/framework/op_lowering.h index 5e909d1196bbc7c21891ef2f3843214e0bdcec90..6059b87ac44d4b5a0bee4895fb7e329e0db31931 100644 --- a/paddle/cinn/hlir/framework/op_lowering.h +++ b/paddle/cinn/hlir/framework/op_lowering.h @@ -39,46 +39,132 @@ using GroupPtr = std::shared_ptr; using common::Target; class OpLowerer; -typedef std::vector (OpLowerer::*IRComputeFunction)( - poly::StageMap&, - std::vector&, - std::unordered_map&, - const GroupPtr&, - const GroupPtr&, - bool); + +typedef bool (OpLowerer::*ScheduleDetermineFunction)(Node*); class OpLowerer { public: OpLowerer(const absl::flat_hash_map&, const absl::flat_hash_map&, const Target&); - std::vector Lower(GroupPtr& group); // NOLINT - std::vector LowerWithoutSchedule(GroupPtr& group); // NOLINT + + /** + * @brief Lower a group to CINN IR. + * @param group The group to be lowered. + * @param apply_op_schedule Whether to schedule at Op level. + * @param apply_group_schedule Whether to schedule at group level. + * @return The lowered funcs. + */ + std::vector Lower(const GroupPtr& group, + bool apply_op_schedule = true, + bool apply_group_schedule = true); private: - std::vector IRLowerOp(IRComputeFunction, GroupPtr&); - std::vector IRLowerNonFusibleOp(GroupPtr&, bool); - std::vector IRLowerOpWithoutSchedule(IRComputeFunction, - GroupPtr&); -#define DEFINE_IR_COMPUTE(type) \ - std::vector IR##type##Compute( \ - poly::StageMap& stages, \ - std::vector& func_args, \ - std::unordered_map& tensor_map, \ - const GroupPtr& group, \ - const GroupPtr& sub_group, \ - bool apply_impl_schedule = false); - - // compute and schedule - DEFINE_IR_COMPUTE(Elementwise); - DEFINE_IR_COMPUTE(Reduce); - DEFINE_IR_COMPUTE(OutEWiseFusable); - - void IRSchedule( + /** + * @brief Lower a group to CINN IR. + * @param group The group to be lowered. + * @param apply_op_schedule Whether to schedule at Op level. + * @param apply_group_schedule Whether to schedule at group level. + * @param schedule_determine_func Function used to determine which Ops to + * schedule. + * @return The lowered funcs. + */ + std::vector LowerGroup( + const GroupPtr& group, + bool apply_op_schedule, + bool apply_group_schedule, + ScheduleDetermineFunction schedule_determine_func); + + /** + * @brief Lower a group composed of CustomCall Op. + * @param group The group to be lowered. + * @return The lowered funcs. + */ + std::vector LowerCustomCall(const GroupPtr& group); + + /** + * @brief Post processing, including preparing function args and temporary + * variables, applying low-level optimization passes, etc. + * @param group The group to be lowered. + * @param tensor_map All tensors used for calculating the group. + * @param done_op_schedule Mark whether the Op level schedule has been + * applied. 
+ * @param ir_sch The IRSchedule object of group. + * @param group_func_arg_tensors Tensors used as the group function arguments. + * @return The lowered funcs after the post processing. + */ + std::vector PostProcess( + const GroupPtr& group, + const std::unordered_map& tensor_map, + bool done_op_schedule, + ir::IRSchedule* ir_sch, + std::vector* group_func_arg_tensors); + + /** + * @brief Lower an Op set to CINN IR. + * Compute, Lower and optional Schedule will be performed one by one + * for each Op. + * @param nodes The Op nodes to be lowered. + * @param apply_op_schedule Whether to schedule at Op level. + * @param schedule_determine_func Function used to determine which Ops to + * schedule. + * @param group_func_arg_tensors Tensors used as the group function arguments. + * @param tensor_map All tensors used for calculating the group. + * @return The lowered func bodies of Op set. + */ + std::vector LowerOps( + const std::vector& nodes, + bool apply_op_schedule, + ScheduleDetermineFunction schedule_determine_func, + std::vector* group_func_arg_tensors, + std::unordered_map* tensor_map); + + /** + * @brief Lower an Op to CINN IR. The Compute and Lower processes will be + * called sequentially. + * @param op_impl The Op implementation defining Compute and Schedule. + * @param node The Op node to be lowered. + * @param tensor_map All tensors used for calculating the group. + * @param op_func_arg_tensors Tensors used as the Op function arguments. + * @return The lowered func of the Op node. + */ + std::vector DoOpLower( + std::shared_ptr op_impl, + Node* node, + std::unordered_map* tensor_map, + std::vector* op_func_arg_tensors); + + /** + * @brief Apply schedule on an Op. + * @param op_impl The Op implementation defining Compute and Schedule. + * @param op_func_arg_tensors Tensors used as the Op function arguments. + * @param lowered_funcs The lowered funcs of an Op to be scheduled. + * @return The lowered func body after schedule of the Op. + */ + ir::Expr DoOpSchedule(std::shared_ptr op_impl, + const std::vector& op_func_arg_tensors, + const std::vector& lowered_funcs); + + /** + * @brief Apply schedule on a group. + * @param ir_sch The IRSchedule containing the entire group's lowered func + * bodies. + * @param group The group to be scheduled. + * @param tensor_map All tensors used for calculating the group. + * @return The lowered func body after schedule of the group. + */ + ir::Expr DoGroupSchedule( ir::IRSchedule& ir_sch, // NOLINT const GroupPtr& group, const std::unordered_map& tensor_map); + // Functions used to determine which Ops to schedule at op level, define a + // policy for each type of group. 
+ inline bool ReduceScheduleDetermineFunction(Node* node); + inline bool ElementwiseScheduleDetermineFunction(Node* node); + inline bool NonFusibleScheduleDetermineFunction(Node* node); + + private: Target target_; const absl::flat_hash_map& type_dict_; const absl::flat_hash_map& shape_dict_; diff --git a/paddle/cinn/hlir/framework/op_lowering_util.cc b/paddle/cinn/hlir/framework/op_lowering_util.cc index 77443cc86d025b1e94f746f49414584c69fe7601..e7a4412202d87badc3fe0cfc53e5f02fb9c8c074 100644 --- a/paddle/cinn/hlir/framework/op_lowering_util.cc +++ b/paddle/cinn/hlir/framework/op_lowering_util.cc @@ -92,19 +92,19 @@ ir::Tensor GetTensor( std::vector CollectInputTensor( const Node* node, - std::vector& func_args, // NOLINT - std::unordered_map& tensor_map, // NOLINT const absl::flat_hash_map& type_dict, - const absl::flat_hash_map& shape_dict) { + const absl::flat_hash_map& shape_dict, + std::vector* func_args, + std::unordered_map* tensor_map) { std::vector tensors; // get all input nodes for (auto& node_data : GetInputNodeData(node)) { CHECK(node_data); auto tensor = GetTensor(node_data, type_dict, shape_dict); - if (!tensor_map.count(node_data->id())) { - tensor_map[node_data->id()] = tensor; + if (!tensor_map->count(node_data->id())) { + (*tensor_map)[node_data->id()] = tensor; // record func input args - func_args.push_back(tensor); + func_args->push_back(tensor); } tensors.push_back(tensor); } diff --git a/paddle/cinn/hlir/framework/op_lowering_util.h b/paddle/cinn/hlir/framework/op_lowering_util.h index 504ee0600479d584f45ac3e35ebed394a691d133..eb8c21fb5c1d54df2ed25c44886cd8f39f2a5f60 100644 --- a/paddle/cinn/hlir/framework/op_lowering_util.h +++ b/paddle/cinn/hlir/framework/op_lowering_util.h @@ -31,10 +31,10 @@ ir::Tensor GetTensor( std::vector CollectInputTensor( const Node* node, - std::vector& func_args, // NOLINT - std::unordered_map& tensor_map, // NOLINT const absl::flat_hash_map& type_dict, - const absl::flat_hash_map& shape_dict); + const absl::flat_hash_map& shape_dict, + std::vector* func_args, + std::unordered_map* tensor_map); std::unordered_map BuildVirtualConsumer( const GroupPtr& group,
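For reference, a minimal sketch (not part of the patch) of how call sites migrate onto the consolidated entry point: the removed LowerWithoutSchedule(group) becomes Lower(group, /*apply_op_schedule=*/false, /*apply_group_schedule=*/false), and the old schedule-applying Lower(group) is now the default (both flags true), as the hunks in tune_task.cc, performance_comparison_test.cc and test_helper.cc above show. The helper name LowerAllGroups and the surrounding setup are illustrative assumptions; only the OpLowerer::Lower signature and the include path come from the diff.

    #include <vector>

    #include "paddle/cinn/hlir/framework/op_lowering.h"

    // Hypothetical helper sketching the new flag-based API; it only builds
    // inside the CINN source tree and is not part of this patch.
    std::vector<cinn::ir::LoweredFunc> LowerAllGroups(
        cinn::hlir::framework::OpLowerer* op_lowerer,
        const std::vector<cinn::hlir::framework::GroupPtr>& fusion_groups,
        bool apply_schedule) {
      std::vector<cinn::ir::LoweredFunc> funcs;
      for (const auto& group : fusion_groups) {
        // apply_op_schedule / apply_group_schedule replace the old split
        // between Lower and LowerWithoutSchedule; passing false for both
        // matches the removed LowerWithoutSchedule behaviour.
        auto lowered = op_lowerer->Lower(group,
                                         /*apply_op_schedule=*/apply_schedule,
                                         /*apply_group_schedule=*/apply_schedule);
        funcs.insert(funcs.end(), lowered.begin(), lowered.end());
      }
      return funcs;
    }

Passing true for both flags (the defaults declared in op_lowering.h) reproduces the behaviour of the old schedule-applying Lower(GroupPtr&) overload.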