Unverified commit 3559252a, authored by BiynXu, committed by GitHub

[CINN] comb the op lowering code (#54982)

* [CINN] comb the op lowering code

* [CINN] format code of OpLower
Parent 27cc0df5
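The central refactor in this diff folds OpLowerer's two entry points (Lower and LowerWithoutSchedule) into a single Lower method controlled by two flags, and the call sites below are migrated accordingly. A minimal before/after sketch of that migration, assuming the usual op_lowerer/graph setup already present in the tests:

    // Before this change: two separate entry points.
    // auto funcs = op_lowerer->Lower(graph->fusion_groups[0]);                 // scheduled
    // auto funcs = op_lowerer->LowerWithoutSchedule(graph->fusion_groups[0]);  // unscheduled

    // After this change: one entry point, schedule behaviour selected by flags.
    std::vector<ir::LoweredFunc> funcs =
        op_lowerer->Lower(graph->fusion_groups[0],
                          /*apply_op_schedule=*/false,
                          /*apply_group_schedule=*/false);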
@@ -71,7 +71,6 @@ TEST(AutoInline, SingleLoopInline) {
                                   nullptr,
                                   target,
                                   true);
   VLOG(6) << "Expr after lowering:";
   VLOG(6) << funcs[0]->body;
@@ -170,7 +169,9 @@ TEST(AutoInline, AddReluInline) {
   EXPECT_EQ(graph->fusion_groups.size(), 1UL);
   std::vector<ir::LoweredFunc> funcs =
-      op_lowerer->LowerWithoutSchedule(graph->fusion_groups[0]);
+      op_lowerer->Lower(graph->fusion_groups[0],
+                        /*apply_op_schedule = */ false,
+                        /*apply_group_schedule=*/false);
   VLOG(6) << "Expr before auto inline: " << funcs[0]->body;
......
@@ -388,9 +388,9 @@ TEST_F(TestMultiLevelTiling, ReduceSum) {
 TEST_F(TestMultiLevelTiling, Pool2d) {
   default_input_names = {"input"};
-  default_output_names = {"var_0"};
-  std::vector<int32_t> input_shape{2, 8, 16, 16};
-  std::vector<int32_t> output_shape{2, 8, 8, 8};
+  default_output_names = {"var_0", "pad_temp_0"};
+  std::vector<std::vector<int32_t>> input_shapes{{2, 8, 16, 16}};
+  std::vector<std::vector<int32_t>> output_shapes{{2, 8, 8, 8}, {2, 8, 18, 18}};
   std::string pooling_type = "max";
   std::vector<int> ksize{3, 3};
   std::vector<int> strides{2, 2};
@@ -402,7 +402,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
   bool adaptive = false;
   std::string padding_algorithm = "EXPLICIT";
   frontend::Program pool2d_program = tests::OpBuilder("pool2d").Build(
-      {{"input", input_shape}},
+      {{"input", input_shapes[0]}},
       {{"pool_type", pooling_type},
        {"kernel_size", ksize},
        {"stride_size", strides},
@@ -440,85 +440,82 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
{
ScheduleBlock(root)
{
- serial for (i, 0, 2)
{
- serial for (j, 0, 8)
+ serial for (i, 0, 2)
{
- serial for (k, 0, 18)
+ serial for (j, 0, 8)
{
- serial for (a, 0, 18)
+ serial for (k, 0, 18)
{
- ScheduleBlock(pad_temp_0)
+ serial for (a, 0, 18)
{
- i0, i1, i2, i3 = axis.bind(i, j, k, a)
+ ScheduleBlock(pad_temp_0)
- pad_temp_0[i, j, k, a] = select(((a < 17) and ((a >= 1) and ((k < 17) and (k >= 1)))), input[i, j, (-1 + k), (-1 + a)], -3.40282347e+38f)
+ {
+ i0, i1, i2, i3 = axis.bind(i, j, k, a)
+ {
+ pad_temp_0[i, j, k, a] = select(((a < 17) and ((a >= 1) and ((k < 17) and (k >= 1)))), input[i, j, (-1 + k), (-1 + a)], -3.40282347e+38f)
+ }
+ }
}
}
}
}
- }
- }
- }
- } // end Expr 0
- Expr 1 {
- {
- ScheduleBlock(root_0)
- {
- {
- thread_bind[blockIdx.x] for (i_j_k_a_fused, 0, 16)
{
- thread_bind[threadIdx.x] for (i_0_j_0_k_0_a_0_fused, 0, 4)
+ thread_bind[blockIdx.x] for (i_j_k_a_fused, 0, 16)
{
- serial for (i_1, 0, 1)
+ thread_bind[threadIdx.x] for (i_0_j_0_k_0_a_0_fused, 0, 4)
{
- serial for (j_1, 0, 4)
+ serial for (i_1, 0, 1)
{
- serial for (k_1, 0, 1)
+ serial for (j_1, 0, 4)
{
- serial for (a_1, 0, 4)
+ serial for (k_1, 0, 1)
{
- ScheduleBlock(var_0__reduce_init)
+ serial for (a_1, 0, 4)
{
- i0_0, i1_0, i2_0, i3_0 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1))
+ ScheduleBlock(var_0__reduce_init)
{
- var_0__reduce_init[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = -3.40282347e+38f
+ i0_0, i1_0, i2_0, i3_0 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1))
+ {
+ var_0__reduce_init[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = -3.40282347e+38f
+ }
}
}
}
}
}
}
{
serial for (kernel_idx, 0, 3)
{ {
serial for (kernel_idx_0, 0, 3) serial for (kernel_idx, 0, 3)
{ {
serial for (ax0_ax1_ax2_ax3_fused, 0, 28) serial for (kernel_idx_0, 0, 3)
{ {
ScheduleBlock(pad_temp_0_shared_temp_buffer) serial for (ax0_ax1_ax2_ax3_fused, 0, 28)
{ {
v0, v1, v2, v3 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + ((ax0_ax1_ax2_ax3_fused / 7) / 4))), (((ax0_ax1_ax2_ax3_fused / 7) % 4) + (4 * (((i_j_k_a_fused / 2) / 2) % 2))), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + kernel_idx)), ((ax0_ax1_ax2_ax3_fused % 7) + ((8 * (i_j_k_a_fused % 2)) + kernel_idx_0))) ScheduleBlock(pad_temp_0_shared_temp_buffer)
attrs(compute_at_extra_var:ax0,ax1,ax2,ax3, cooperative_process:0)
{ {
pad_temp_0_shared_temp_buffer[v0, v1, v2, v3] = pad_temp_0[v0, v1, v2, v3] v0, v1, v2, v3 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + ((ax0_ax1_ax2_ax3_fused / 7) / 4))), (((ax0_ax1_ax2_ax3_fused / 7) % 4) + (4 * (((i_j_k_a_fused / 2) / 2) % 2))), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + kernel_idx)), ((ax0_ax1_ax2_ax3_fused % 7) + ((8 * (i_j_k_a_fused % 2)) + kernel_idx_0)))
attrs(compute_at_extra_var:ax0,ax1,ax2,ax3, cooperative_process:0)
{
pad_temp_0_shared_temp_buffer[v0, v1, v2, v3] = pad_temp_0[v0, v1, v2, v3]
}
} }
} }
} serial for (i_1, 0, 1)
serial for (i_1, 0, 1)
{
serial for (j_1, 0, 4)
{ {
serial for (k_1, 0, 1) serial for (j_1, 0, 4)
{ {
serial for (a_1, 0, 4) serial for (k_1, 0, 1)
{ {
ScheduleBlock(var_0_local_temp_buffer) serial for (a_1, 0, 4)
{ {
i0_1, i1_1, i2_1, i3_1, i4, i5 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1), kernel_idx, kernel_idx_0) ScheduleBlock(var_0_local_temp_buffer)
read_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)], _pad_temp_0[i(undefined:undefined), j(undefined:undefined)])
write_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)])
{ {
var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = cinn_max(var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)], pad_temp_0_shared_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + ((2 * k_1) + kernel_idx))), ((8 * (i_j_k_a_fused % 2)) + ((2 * a_1) + kernel_idx_0))]) i0_1, i1_1, i2_1, i3_1, i4, i5 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1), kernel_idx, kernel_idx_0)
read_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)], _pad_temp_0[i(undefined:undefined), j(undefined:undefined)])
write_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)])
{
var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = cinn_max(var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)], pad_temp_0_shared_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + ((2 * k_1) + kernel_idx))), ((8 * (i_j_k_a_fused % 2)) + ((2 * a_1) + kernel_idx_0))])
}
} }
} }
} }
@@ -526,21 +523,21 @@ Expr 1 {
} }
} }
} }
} serial for (ax0_0, 0, 1)
serial for (ax0_0, 0, 1)
{
serial for (ax1_0, 0, 4)
{ {
serial for (ax2_0, 0, 1) serial for (ax1_0, 0, 4)
{ {
serial for (ax3_0, 0, 4) serial for (ax2_0, 0, 1)
{ {
ScheduleBlock(var_0) serial for (ax3_0, 0, 4)
{ {
v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((4 * ((i_j_k_a_fused / 2) % 2)) + (i_0_j_0_k_0_a_0_fused % 4)) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0)) ScheduleBlock(var_0)
attrs(reverse_compute_at_extra_var:ax0_0,ax1_0,ax2_0,ax3_0)
{ {
var_0[v0, v1, v2, v3] = var_0_local_temp_buffer[v0, v1, v2, v3] v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((4 * ((i_j_k_a_fused / 2) % 2)) + (i_0_j_0_k_0_a_0_fused % 4)) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0))
attrs(reverse_compute_at_extra_var:ax0_0,ax1_0,ax2_0,ax3_0)
{
var_0[v0, v1, v2, v3] = var_0_local_temp_buffer[v0, v1, v2, v3]
}
} }
} }
} }
@@ -553,7 +550,7 @@ Expr 1 {
} }
} }
} }
- } // end Expr 1
+ } // end Expr 0
)ROC"; )ROC";
ASSERT_EQ(ir, expected_ir); ASSERT_EQ(ir, expected_ir);
@@ -569,8 +566,8 @@ Expr 1 {
           pool2d_program, fixed_rand_seed, /* apply_manual_schedule*/ true))),
       default_input_names,
       default_output_names,
-      {input_shape},
-      {output_shape},
+      input_shapes,
+      output_shapes,
       target_);
 }
......
@@ -63,12 +63,10 @@ ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule(
       absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
   hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target_);
-  if (apply_manual_schedule) {
-    lowered_funcs_ = op_lowerer.Lower(graph->fusion_groups.front());
-  } else {
-    lowered_funcs_ =
-        op_lowerer.LowerWithoutSchedule(graph->fusion_groups.front());
-  }
+  lowered_funcs_ =
+      op_lowerer.Lower(graph->fusion_groups.front(),
+                       /*apply_op_schedule = */ apply_manual_schedule,
+                       /*apply_group_schedule = */ apply_manual_schedule);
   CHECK(!lowered_funcs_.empty()) << "lowered_funcs_ is empty";
   std::vector<Expr> bodys;
......
@@ -39,7 +39,8 @@ void TuneTask::Initialize(
   op_lowerer = lower_handler;
   // Set lowered_funcs and analyze output names.
-  this->lowered_funcs = op_lowerer->LowerWithoutSchedule(subgraph);
+  this->lowered_funcs = op_lowerer->Lower(
+      subgraph, /*apply_op_schedule = */ false, /*apply_group_schedule=*/false);
   this->output_names = GetOutputNamesFromLoweredFunc(this->lowered_funcs);
   this->serialized_key = SerializeToString(shape_dict, dtype_dict);
 }
......
@@ -157,7 +157,9 @@ class PerformanceTester : public ::testing::Test {
     for (auto group : graph->fusion_groups) {
       compile_options.lowered_funcs.push_back(
-          op_lowerer->LowerWithoutSchedule(group));
+          op_lowerer->Lower(group,
+                            /*apply_op_schedule = */ false,
+                            /*apply_group_schedule=*/false));
     }
VLOG(3) << "===========================No Schedule LoweredFunc " VLOG(3) << "===========================No Schedule LoweredFunc "
......
@@ -39,46 +39,132 @@ using GroupPtr = std::shared_ptr<Graph::Group>;
 using common::Target;
 class OpLowerer;
-typedef std::vector<Expr> (OpLowerer::*IRComputeFunction)(
-    poly::StageMap&,
-    std::vector<ir::Tensor>&,
-    std::unordered_map<std::string, ir::Tensor>&,
-    const GroupPtr&,
-    const GroupPtr&,
-    bool);
+typedef bool (OpLowerer::*ScheduleDetermineFunction)(Node*);
 class OpLowerer {
  public:
   OpLowerer(const absl::flat_hash_map<std::string, Type>&,
             const absl::flat_hash_map<std::string, shape_t>&,
             const Target&);
-  std::vector<ir::LoweredFunc> Lower(GroupPtr& group);  // NOLINT
-  std::vector<ir::LoweredFunc> LowerWithoutSchedule(GroupPtr& group);  // NOLINT
+  /**
+   * @brief Lower a group to CINN IR.
+   * @param group The group to be lowered.
+   * @param apply_op_schedule Whether to schedule at Op level.
+   * @param apply_group_schedule Whether to schedule at group level.
+   * @return The lowered funcs.
+   */
+  std::vector<ir::LoweredFunc> Lower(const GroupPtr& group,
+                                     bool apply_op_schedule = true,
+                                     bool apply_group_schedule = true);
  private:
-  std::vector<ir::LoweredFunc> IRLowerOp(IRComputeFunction, GroupPtr&);
-  std::vector<ir::LoweredFunc> IRLowerNonFusibleOp(GroupPtr&, bool);
-  std::vector<ir::LoweredFunc> IRLowerOpWithoutSchedule(IRComputeFunction,
-                                                        GroupPtr&);
-#define DEFINE_IR_COMPUTE(type) \
-  std::vector<Expr> IR##type##Compute( \
-      poly::StageMap& stages, \
-      std::vector<ir::Tensor>& func_args, \
-      std::unordered_map<std::string, ir::Tensor>& tensor_map, \
-      const GroupPtr& group, \
-      const GroupPtr& sub_group, \
-      bool apply_impl_schedule = false);
-  // compute and schedule
-  DEFINE_IR_COMPUTE(Elementwise);
-  DEFINE_IR_COMPUTE(Reduce);
-  DEFINE_IR_COMPUTE(OutEWiseFusable);
-  void IRSchedule(
+  /**
+   * @brief Lower a group to CINN IR.
+   * @param group The group to be lowered.
+   * @param apply_op_schedule Whether to schedule at Op level.
+   * @param apply_group_schedule Whether to schedule at group level.
+   * @param schedule_determine_func Function used to determine which Ops to
+   * schedule.
+   * @return The lowered funcs.
+   */
+  std::vector<ir::LoweredFunc> LowerGroup(
+      const GroupPtr& group,
+      bool apply_op_schedule,
+      bool apply_group_schedule,
+      ScheduleDetermineFunction schedule_determine_func);
+  /**
+   * @brief Lower a group composed of CustomCall Op.
+   * @param group The group to be lowered.
+   * @return The lowered funcs.
+   */
+  std::vector<ir::LoweredFunc> LowerCustomCall(const GroupPtr& group);
/**
* @brief Post processing, including preparing function args and temporary
* variables, applying low-level optimization passes, etc.
* @param group The group to be lowered.
* @param tensor_map All tensors used for calculating the group.
* @param done_op_schedule Mark whether the Op level schedule has been
* applied.
* @param ir_sch The IRSchedule object of group.
* @param group_func_arg_tensors Tensors used as the group function arguments.
* @return The lowered funcs after the post processing.
*/
std::vector<ir::LoweredFunc> PostProcess(
const GroupPtr& group,
const std::unordered_map<std::string, ir::Tensor>& tensor_map,
bool done_op_schedule,
ir::IRSchedule* ir_sch,
std::vector<ir::Tensor>* group_func_arg_tensors);
/**
* @brief Lower an Op set to CINN IR.
* Compute, Lower and optional Schedule will be performed one by one
* for each Op.
* @param nodes The Op nodes to be lowered.
* @param apply_op_schedule Whether to schedule at Op level.
* @param schedule_determine_func Function used to determine which Ops to
* schedule.
* @param group_func_arg_tensors Tensors used as the group function arguments.
* @param tensor_map All tensors used for calculating the group.
* @return The lowered func bodies of Op set.
*/
std::vector<ir::Expr> LowerOps(
const std::vector<Node*>& nodes,
bool apply_op_schedule,
ScheduleDetermineFunction schedule_determine_func,
std::vector<ir::Tensor>* group_func_arg_tensors,
std::unordered_map<std::string, ir::Tensor>* tensor_map);
/**
* @brief Lower an Op to CINN IR. The Compute and Lower processes will be
* called sequentially.
* @param op_impl The Op implementation defining Compute and Schedule.
* @param node The Op node to be lowered.
* @param tensor_map All tensors used for calculating the group.
* @param op_func_arg_tensors Tensors used as the Op function arguments.
* @return The lowered func of the Op node.
*/
std::vector<ir::LoweredFunc> DoOpLower(
std::shared_ptr<hlir::framework::OpImpl> op_impl,
Node* node,
std::unordered_map<std::string, ir::Tensor>* tensor_map,
std::vector<ir::Tensor>* op_func_arg_tensors);
/**
* @brief Apply schedule on an Op.
* @param op_impl The Op implementation defining Compute and Schedule.
* @param op_func_arg_tensors Tensors used as the Op function arguments.
* @param lowered_funcs The lowered funcs of an Op to be scheduled.
* @return The lowered func body after schedule of the Op.
*/
ir::Expr DoOpSchedule(std::shared_ptr<hlir::framework::OpImpl> op_impl,
const std::vector<ir::Tensor>& op_func_arg_tensors,
const std::vector<ir::LoweredFunc>& lowered_funcs);
/**
* @brief Apply schedule on a group.
* @param ir_sch The IRSchedule containing the entire group's lowered func
* bodies.
* @param group The group to be scheduled.
* @param tensor_map All tensors used for calculating the group.
* @return The lowered func body after schedule of the group.
*/
ir::Expr DoGroupSchedule(
      ir::IRSchedule& ir_sch,  // NOLINT
      const GroupPtr& group,
      const std::unordered_map<std::string, ir::Tensor>& tensor_map);
// Functions used to determine which Ops to schedule at op level, define a
// policy for each type of group.
inline bool ReduceScheduleDetermineFunction(Node* node);
inline bool ElementwiseScheduleDetermineFunction(Node* node);
inline bool NonFusibleScheduleDetermineFunction(Node* node);
private:
  Target target_;
  const absl::flat_hash_map<std::string, Type>& type_dict_;
  const absl::flat_hash_map<std::string, shape_t>& shape_dict_;
......
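The new private API is built around the ScheduleDetermineFunction typedef declared above: a pointer to an OpLowerer member function that decides, per Op node, whether the Op-level schedule should be applied. The following is a minimal, self-contained sketch of that member-function-pointer pattern only; OpLowererSketch, the simplified Node, and the policy body are illustrative stand-ins and not the actual CINN implementation.

    #include <vector>

    struct Node {
      bool is_reduce = false;  // stand-in for CINN's real graph node
    };

    class OpLowererSketch {
     public:
      typedef bool (OpLowererSketch::*ScheduleDetermineFunction)(Node*);

      // Walk the group's nodes and schedule only those the policy selects,
      // mirroring the apply_op_schedule + schedule_determine_func contract.
      int LowerOps(const std::vector<Node*>& nodes,
                   bool apply_op_schedule,
                   ScheduleDetermineFunction schedule_determine_func) {
        int scheduled = 0;
        for (Node* node : nodes) {
          // Invoke the policy through the member-function pointer.
          if (apply_op_schedule && (this->*schedule_determine_func)(node)) {
            ++scheduled;  // a real lowerer would call DoOpSchedule here
          }
        }
        return scheduled;
      }

      // Example policy: schedule every non-reduce op (illustrative only).
      bool ElementwiseScheduleDetermineFunction(Node* node) {
        return !node->is_reduce;
      }
    };

    // Usage: lowerer.LowerOps(nodes, /*apply_op_schedule=*/true,
    //     &OpLowererSketch::ElementwiseScheduleDetermineFunction);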
@@ -92,19 +92,19 @@ ir::Tensor GetTensor(
 std::vector<ir::Tensor> CollectInputTensor(
     const Node* node,
-    std::vector<ir::Tensor>& func_args,                        // NOLINT
-    std::unordered_map<std::string, ir::Tensor>& tensor_map,  // NOLINT
     const absl::flat_hash_map<std::string, Type>& type_dict,
-    const absl::flat_hash_map<std::string, shape_t>& shape_dict) {
+    const absl::flat_hash_map<std::string, shape_t>& shape_dict,
+    std::vector<ir::Tensor>* func_args,
+    std::unordered_map<std::string, ir::Tensor>* tensor_map) {
   std::vector<ir::Tensor> tensors;
   // get all input nodes
   for (auto& node_data : GetInputNodeData(node)) {
     CHECK(node_data);
     auto tensor = GetTensor(node_data, type_dict, shape_dict);
-    if (!tensor_map.count(node_data->id())) {
-      tensor_map[node_data->id()] = tensor;
+    if (!tensor_map->count(node_data->id())) {
+      (*tensor_map)[node_data->id()] = tensor;
       // record func input args
-      func_args.push_back(tensor);
+      func_args->push_back(tensor);
     }
     tensors.push_back(tensor);
   }
......
@@ -31,10 +31,10 @@ ir::Tensor GetTensor(
 std::vector<ir::Tensor> CollectInputTensor(
     const Node* node,
-    std::vector<ir::Tensor>& func_args,                        // NOLINT
-    std::unordered_map<std::string, ir::Tensor>& tensor_map,  // NOLINT
     const absl::flat_hash_map<std::string, Type>& type_dict,
-    const absl::flat_hash_map<std::string, shape_t>& shape_dict);
+    const absl::flat_hash_map<std::string, shape_t>& shape_dict,
+    std::vector<ir::Tensor>* func_args,
+    std::unordered_map<std::string, ir::Tensor>* tensor_map);
 std::unordered_map<Node*, Node*> BuildVirtualConsumer(
     const GroupPtr& group,
......
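The same refactor switches CollectInputTensor's output containers from mutable references to pointer out-parameters. A small, hedged sketch of the new call shape; the node and the two dictionaries are assumed to come from the surrounding lowering context, and the local variable names are illustrative only:

    std::vector<ir::Tensor> group_func_arg_tensors;
    std::unordered_map<std::string, ir::Tensor> tensor_map;
    std::vector<ir::Tensor> op_inputs = CollectInputTensor(
        node, type_dict, shape_dict, &group_func_arg_tensors, &tensor_map);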