Unverified commit 3559252a authored by BiynXu, committed by GitHub

[CINN] comb the op lowering code (#54982)

* [CINN] comb the op lowering code

* [CINN] format code of OpLower
Parent 27cc0df5
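This commit collapses the two lowering entry points into a single Lower with per-level schedule switches. A minimal before/after sketch of the call-site migration performed in the diff below (variable names illustrative):

  // Before: two entry points on OpLowerer
  std::vector<ir::LoweredFunc> scheduled = op_lowerer->Lower(group);
  std::vector<ir::LoweredFunc> unscheduled =
      op_lowerer->LowerWithoutSchedule(group);

  // After: one entry point, schedule toggled per level
  std::vector<ir::LoweredFunc> unscheduled_now =
      op_lowerer->Lower(group,
                        /*apply_op_schedule=*/false,
                        /*apply_group_schedule=*/false);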
@@ -71,7 +71,6 @@ TEST(AutoInline, SingleLoopInline) {
nullptr,
target,
true);
VLOG(6) << "Expr after lowering:";
VLOG(6) << funcs[0]->body;
@@ -170,7 +169,9 @@ TEST(AutoInline, AddReluInline) {
EXPECT_EQ(graph->fusion_groups.size(), 1UL);
std::vector<ir::LoweredFunc> funcs =
op_lowerer->LowerWithoutSchedule(graph->fusion_groups[0]);
op_lowerer->Lower(graph->fusion_groups[0],
/*apply_op_schedule = */ false,
/*apply_group_schedule=*/false);
VLOG(6) << "Expr before auto inline: " << funcs[0]->body;
......
@@ -388,9 +388,9 @@ TEST_F(TestMultiLevelTiling, ReduceSum) {
TEST_F(TestMultiLevelTiling, Pool2d) {
default_input_names = {"input"};
default_output_names = {"var_0"};
std::vector<int32_t> input_shape{2, 8, 16, 16};
std::vector<int32_t> output_shape{2, 8, 8, 8};
default_output_names = {"var_0", "pad_temp_0"};
std::vector<std::vector<int32_t>> input_shapes{{2, 8, 16, 16}};
std::vector<std::vector<int32_t>> output_shapes{{2, 8, 8, 8}, {2, 8, 18, 18}};
std::string pooling_type = "max";
std::vector<int> ksize{3, 3};
std::vector<int> strides{2, 2};
@@ -402,7 +402,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
bool adaptive = false;
std::string padding_algorithm = "EXPLICIT";
frontend::Program pool2d_program = tests::OpBuilder("pool2d").Build(
{{"input", input_shape}},
{{"input", input_shapes[0]}},
{{"pool_type", pooling_type},
{"kernel_size", ksize},
{"stride_size", strides},
@@ -439,6 +439,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
std::string expected_ir = R"ROC(Expr 0 {
{
ScheduleBlock(root)
{
{
serial for (i, 0, 2)
{
@@ -451,6 +452,7 @@
ScheduleBlock(pad_temp_0)
{
i0, i1, i2, i3 = axis.bind(i, j, k, a)
{
pad_temp_0[i, j, k, a] = select(((a < 17) and ((a >= 1) and ((k < 17) and (k >= 1)))), input[i, j, (-1 + k), (-1 + a)], -3.40282347e+38f)
}
}
@@ -458,12 +460,6 @@
}
}
}
}
} // end Expr 0
Expr 1 {
{
ScheduleBlock(root_0)
{
{
thread_bind[blockIdx.x] for (i_j_k_a_fused, 0, 16)
{
@@ -552,8 +548,9 @@ Expr 1 {
}
}
}
}
}
} // end Expr 1
} // end Expr 0
)ROC";
ASSERT_EQ(ir, expected_ir);
@@ -569,8 +566,8 @@ Expr 1 {
pool2d_program, fixed_rand_seed, /* apply_manual_schedule*/ true))),
default_input_names,
default_output_names,
{input_shape},
{output_shape},
input_shapes,
output_shapes,
target_);
}
......
@@ -63,12 +63,10 @@ ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule(
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target_);
if (apply_manual_schedule) {
lowered_funcs_ = op_lowerer.Lower(graph->fusion_groups.front());
} else {
lowered_funcs_ =
op_lowerer.LowerWithoutSchedule(graph->fusion_groups.front());
}
op_lowerer.Lower(graph->fusion_groups.front(),
/*apply_op_schedule = */ apply_manual_schedule,
/*apply_group_schedule = */ apply_manual_schedule);
CHECK(!lowered_funcs_.empty()) << "lowered_funcs_ is empty";
std::vector<Expr> bodys;
......
@@ -39,7 +39,8 @@ void TuneTask::Initialize(
op_lowerer = lower_handler;
// Set lowered_funcs and analyze output names.
this->lowered_funcs = op_lowerer->LowerWithoutSchedule(subgraph);
this->lowered_funcs = op_lowerer->Lower(
subgraph, /*apply_op_schedule = */ false, /*apply_group_schedule=*/false);
this->output_names = GetOutputNamesFromLoweredFunc(this->lowered_funcs);
this->serialized_key = SerializeToString(shape_dict, dtype_dict);
}
......
@@ -157,7 +157,9 @@ class PerformanceTester : public ::testing::Test {
for (auto group : graph->fusion_groups) {
compile_options.lowered_funcs.push_back(
op_lowerer->LowerWithoutSchedule(group));
op_lowerer->Lower(group,
/*apply_op_schedule = */ false,
/*apply_group_schedule=*/false));
}
VLOG(3) << "===========================No Schedule LoweredFunc "
......
@@ -45,7 +45,9 @@ OpLowerer::OpLowerer(
const Target& target)
: type_dict_(type_dict), shape_dict_(shape_dict), target_(target) {}
std::vector<ir::LoweredFunc> OpLowerer::Lower(GroupPtr& group) { // NOLINT
std::vector<ir::LoweredFunc> OpLowerer::Lower(const GroupPtr& group,
bool apply_op_schedule,
bool apply_group_schedule) {
VLOG(3) << "Lowering Group : " << group->group_id
<< " , Op Pattern : " << group->op_pattern_kind;
group->input_names.clear();
@@ -55,13 +57,22 @@ std::vector<ir::LoweredFunc> OpLowerer::Lower(GroupPtr& group) { // NOLINT
case framework::kElementWise:
case framework::kBroadcast:
case framework::kInjective:
return IRLowerOp(&OpLowerer::IRElementwiseCompute, group);
return LowerGroup(group,
apply_op_schedule,
apply_group_schedule,
&OpLowerer::ElementwiseScheduleDetermineFunction);
case framework::kReduction:
return IRLowerOp(&OpLowerer::IRReduceCompute, group);
return LowerGroup(group,
apply_op_schedule,
apply_group_schedule,
&OpLowerer::ReduceScheduleDetermineFunction);
case framework::kOutFusible:
LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!";
case framework::kNonFusible:
return IRLowerNonFusibleOp(group, /*apply_impl_schedule = */ true);
return LowerGroup(group,
apply_op_schedule,
apply_group_schedule,
&OpLowerer::NonFusibleScheduleDetermineFunction);
default:
LOG(FATAL) << "Group Pattern Kind Is Unknown!";
}
@@ -70,532 +81,329 @@ std::vector<ir::LoweredFunc> OpLowerer::Lower(GroupPtr& group) { // NOLINT
}
}
std::vector<ir::LoweredFunc> OpLowerer::LowerWithoutSchedule(GroupPtr& group) {
VLOG(3) << "Lowering Group : " << group->group_id
<< " , Op Pattern : " << group->op_pattern_kind;
if (FLAGS_cinn_ir_schedule) {
switch (group->op_pattern_kind) {
case framework::kElementWise:
case framework::kBroadcast:
case framework::kInjective:
return IRLowerOpWithoutSchedule(&OpLowerer::IRElementwiseCompute,
group);
case framework::kReduction:
return IRLowerOpWithoutSchedule(&OpLowerer::IRReduceCompute, group);
case framework::kOutFusible:
LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!";
case framework::kNonFusible:
return IRLowerNonFusibleOp(group, /*apply_impl_schedule = */ false);
default:
LOG(FATAL) << "Group Pattern Kind kNonFusible Is Not Implemented!";
}
} else {
LOG(FATAL) << "Previous IR Schedule Is Not Implemented!";
}
bool OpLowerer::ElementwiseScheduleDetermineFunction(Node* node) {
return true;
}
std::vector<ir::LoweredFunc> OpLowerer::IRLowerOp(IRComputeFunction compute,
GroupPtr& group) {
poly::StageMap stages;
std::vector<ir::Tensor> arg_tensors;
std::unordered_map<std::string, ir::Tensor> tensor_map;
// do compute.
bool OpLowerer::ReduceScheduleDetermineFunction(Node* node) {
auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
return op_pattern_dict[node->op()] == framework::kReduction;
}
bool OpLowerer::NonFusibleScheduleDetermineFunction(Node* node) { return true; }
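The three predicates above are selected per group pattern in Lower and invoked through the ScheduleDetermineFunction member-function pointer (see LowerOps below). A minimal, self-contained sketch of that dispatch idiom — all types here are simplified stand-ins, not the CINN API:

  #include <iostream>

  struct Node {
    bool is_reduction = false;
  };

  class Lowerer {
   public:
    // Mirrors: typedef bool (OpLowerer::*ScheduleDetermineFunction)(Node*);
    typedef bool (Lowerer::*SchedulePredicate)(Node*);

    bool AlwaysSchedule(Node*) { return true; }
    bool ReductionOnly(Node* node) { return node->is_reduction; }

    void LowerOp(Node* node, SchedulePredicate pred) {
      // Invoke the member-function pointer on `this`, just as LowerOps
      // does with (this->*schedule_determine_func)(node).
      if ((this->*pred)(node)) {
        std::cout << "apply op schedule\n";
      } else {
        std::cout << "keep unscheduled body\n";
      }
    }
  };

  int main() {
    Lowerer lowerer;
    Node elementwise;  // is_reduction == false
    lowerer.LowerOp(&elementwise, &Lowerer::ReductionOnly);   // unscheduled
    lowerer.LowerOp(&elementwise, &Lowerer::AlwaysSchedule);  // scheduled
    return 0;
  }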
std::vector<ir::LoweredFunc> OpLowerer::LowerGroup(
const GroupPtr& group,
bool apply_op_schedule,
bool apply_group_schedule,
ScheduleDetermineFunction schedule_determine_func) {
// 1.Do compute, lower and schedule for each op.
VLOG(3) << "group->fused_sub_groups.size() is : "
<< group->fused_sub_groups.size();
std::vector<Expr> ast_exprs;
if (group->fused_sub_groups.size() == 0) {
ast_exprs = (this->*compute)(stages,
arg_tensors,
tensor_map,
group,
group,
/*apply_impl_schedule = */ true);
} else {
for (auto& sub_group : group->fused_sub_groups) {
auto exprs = (this->*compute)(stages,
arg_tensors,
tensor_map,
group,
sub_group,
/*apply_impl_schedule = */ true);
ast_exprs.insert(ast_exprs.end(), exprs.begin(), exprs.end());
}
std::vector<Node*> nodes = group->CollectNodes();
if (nodes.size() == 1 && nodes[0]->op()->name == "custom_call") {
return LowerCustomCall(group);
}
ir::ModuleExpr mod_expr(ast_exprs);
std::vector<ir::Tensor> group_func_arg_tensors;
std::unordered_map<std::string, ir::Tensor> tensor_map;
bool do_op_schedule = apply_group_schedule || apply_op_schedule;
std::vector<ir::Expr> func_bodies = LowerOps(nodes,
do_op_schedule,
schedule_determine_func,
&group_func_arg_tensors,
&tensor_map);
// 2.Do group schedule.
ir::ModuleExpr mod_expr(func_bodies);
ir::IRSchedule ir_sch(mod_expr);
ir_sch.MergeExprs();
Node* first = nullptr;
Node* second = nullptr;
VLOG(3) << "Before IRLowerOp schedule, ir is: \n"
VLOG(3) << "After lower, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
if (apply_group_schedule) {
DoGroupSchedule(ir_sch, group, tensor_map);
VLOG(3) << "After group schedule, ir is: \n"
<< ir_sch.GetModule().GetExprs().at(0);
// do schedule.
IRSchedule(ir_sch, group, tensor_map);
VLOG(3) << "After IRLowerOp schedule, ir is: \n"
<< ir_sch.GetModule().GetExprs().at(0);
// function args
group->input_names.clear();
std::vector<ir::Argument> func_args;
for (auto& args : arg_tensors) {
// input node data name.
group->input_names.push_back(args->name);
// input args
func_args.emplace_back(args->buffer, ir::Argument::IO::kInput);
}
group->output_names.clear();
for (auto& node : group->output_nodes) {
// output node data name.
for (auto node_data : GetAllNodeData(node)) {
group->output_names.push_back(node_data->id());
}
// collect all output tensor.
std::string post = "";
std::string prefix = GetNodeData(node)->id();
for (int idx = 0; idx < 1; ++idx) {
CHECK(tensor_map.count(prefix)) << "Can't find output tensor " << prefix;
if (!tensor_map.count(prefix + post)) {
break;
}
auto tensor = tensor_map[prefix + post];
arg_tensors.push_back(tensor);
// output args
func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput);
// update post
post = "_" + std::to_string(idx);
}
}
auto func_body = ir_sch.GetModule().GetExprs().at(0);
#ifdef CINN_WITH_CUDA
optim::OptimizeExprGPU(&(func_body));
#endif
auto temp_buffers = lang::GetTempBuffers(arg_tensors, stages, func_body);
auto func = ir::_LoweredFunc_::Make(group->GetFuncName(),
func_args,
ir_sch.GetModule().GetExprs().at(0),
temp_buffers);
func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref();
return {func};
// 3.Do post-processing,
// including preparing function args and temporary variables,
// applying low-level optimization passes, etc.
return PostProcess(
group, tensor_map, do_op_schedule, &ir_sch, &group_func_arg_tensors);
}
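After the refactor, LowerGroup reads as a fixed three-stage pipeline; schematically (a paraphrase of the code above, not additional API):

  // 1. Per-op compute/lower, with optional op-level schedule:
  //      func_bodies = LowerOps(nodes, do_op_schedule, schedule_determine_func, ...);
  // 2. Optional group-level schedule over the merged module:
  //      if (apply_group_schedule) DoGroupSchedule(ir_sch, group, tensor_map);
  // 3. Post-processing: function args, temp buffers, low-level passes:
  //      return PostProcess(group, tensor_map, do_op_schedule, &ir_sch, ...);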
std::vector<ir::LoweredFunc> OpLowerer::IRLowerOpWithoutSchedule(
IRComputeFunction compute, GroupPtr& group) {
poly::StageMap stages;
std::vector<ir::Tensor> arg_tensors;
std::vector<ir::LoweredFunc> OpLowerer::LowerCustomCall(const GroupPtr& group) {
std::vector<Node*> nodes = group->CollectNodes();
CHECK_EQ(nodes.size(), 1);
Node* node = nodes[0];
std::vector<ir::Tensor> op_func_arg_tensors;
std::unordered_map<std::string, ir::Tensor> tensor_map;
// do compute.
VLOG(3) << "group->fused_sub_groups.size() is : "
<< group->fused_sub_groups.size();
std::vector<Expr> ast_exprs;
if (group->fused_sub_groups.size() == 0) {
ast_exprs = (this->*compute)(stages,
arg_tensors,
tensor_map,
group,
group,
/*apply_impl_schedule = */ false);
for (auto& node_data : GetInputNodeData(node)) {
CHECK(node_data);
ir::Tensor tensor;
if (!tensor_map.count(node_data->id())) {
tensor = GetTensor(node_data, this->type_dict_, this->shape_dict_);
// record tensor.
tensor_map[node_data->id()] = tensor;
// input name.
group->input_names.push_back(node_data->id());
} else {
for (auto& sub_group : group->fused_sub_groups) {
auto exprs = (this->*compute)(stages,
arg_tensors,
tensor_map,
group,
sub_group,
/*apply_impl_schedule = */ false);
ast_exprs.insert(ast_exprs.end(), exprs.begin(), exprs.end());
tensor = tensor_map[node_data->id()];
}
op_func_arg_tensors.push_back(tensor);
}
ir::ModuleExpr mod_expr(ast_exprs);
ir::IRSchedule ir_sch(mod_expr);
ir_sch.MergeExprs();
VLOG(3) << "After IRLowerOp compute, ir is: \n"
<< ir_sch.GetModule().GetExprs().at(0);
// function args
std::vector<Type> out_types;
std::vector<std::vector<int>> out_shapes;
auto node_datas = GetAllNodeData(node);
for (auto node_data : node_datas) {
group->output_names.push_back(node_data->id());
out_types.push_back(this->type_dict_.at(node_data->id()));
out_shapes.push_back(this->shape_dict_.at(node_data->id()));
}
auto& cinn_strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
auto impl = OpStrategy::SelectImpl(cinn_strategy[node->op()](
node->attrs, op_func_arg_tensors, out_types, out_shapes, target_));
std::string external_api;
if (node->attrs.attr_store.count("custom_call")) {
external_api =
absl::get<std::string>(node->attrs.attr_store.at("custom_call"));
} else {
external_api = ExternalApiRegistry::Global()->GetExternalApi(node, target_);
}
std::vector<common::CINNValue> compute_args = {
common::CINNValue(group->GetFuncName()), common::CINNValue(external_api)};
common::CINNValuePack pack =
impl->fcompute(common::CINNValuePack{compute_args});
CHECK_EQ(pack.size(), 1UL);
// reset input names as extern api input args can't be deduplicated.
group->input_names.clear();
std::vector<ir::Argument> func_args;
for (auto& args : arg_tensors) {
for (auto& inode : node->inlinks_in_order()) {
group->input_names.push_back(inode->source()->as<NodeData>()->id());
}
return {pack[0].operator ir::Expr().as_lowered_func_ref()};
}
std::vector<ir::LoweredFunc> OpLowerer::PostProcess(
const GroupPtr& group,
const std::unordered_map<std::string, ir::Tensor>& tensor_map,
bool done_op_schedule,
ir::IRSchedule* ir_sch,
std::vector<ir::Tensor>* group_func_arg_tensors) {
// 1.Prepare function args
group->input_names.clear();
std::vector<ir::Argument> group_func_args;
std::unordered_set<std::string> arg_name_set;
for (auto& arg_tensor : *group_func_arg_tensors) {
// input node data name.
group->input_names.push_back(args->name);
group->input_names.push_back(arg_tensor->name);
// input args
func_args.emplace_back(args->buffer, ir::Argument::IO::kInput);
group_func_args.emplace_back(arg_tensor->buffer, ir::Argument::IO::kInput);
arg_name_set.insert(arg_tensor->buffer->name);
}
group->output_names.clear();
for (auto& node : group->output_nodes) {
// output node data name.
// collect all output tensor.
for (auto node_data : GetAllNodeData(node)) {
group->output_names.push_back(node_data->id());
std::string output_node_data_name = node_data->id();
group->output_names.push_back(output_node_data_name);
// CHECK(tensor_map.count(output_node_data_name)) << "Can't find output
// tensor " << output_node_data_name;
if (tensor_map.count(output_node_data_name) == 0) {
continue;
}
// collect all output tensor.
std::string post = "";
std::string prefix = GetNodeData(node)->id();
for (int idx = 0; idx < 1; ++idx) {
CHECK(tensor_map.count(prefix)) << "Can't find output tensor " << prefix;
if (!tensor_map.count(prefix + post)) {
break;
}
auto tensor = tensor_map[prefix + post];
arg_tensors.push_back(tensor);
auto tensor = tensor_map.at(output_node_data_name);
if (arg_name_set.count(tensor->buffer->name) != 0) {
continue;
}
// output arg tensors
group_func_arg_tensors->push_back(tensor);
// output args
func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput);
// update post
post = "_" + std::to_string(idx);
group_func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput);
arg_name_set.insert(tensor->buffer->name);
}
}
std::unordered_set<std::string> args_map;
for (auto arg : func_args) {
args_map.insert(arg.name());
if (!done_op_schedule) {
std::unordered_set<std::string> args_set;
for (auto arg : group_func_args) {
args_set.insert(arg.name());
}
for (auto& tensor : tensor_map) {
if (args_map.count("_" + tensor.first)) {
for (auto& tensor_pair : tensor_map) {
if (args_set.count("_" + tensor_pair.second->name)) {
continue;
}
arg_tensors.push_back(tensor.second);
// use the underlying tensor name to be consistent with the argument name in
// the lowered function
group->output_names.push_back(tensor.second->name);
func_args.emplace_back(tensor.second->buffer, ir::Argument::IO::kOutput);
group_func_arg_tensors->push_back(tensor_pair.second);
// use the underlying tensor name to be consistent with the argument name
// in the lowered function
group->output_names.push_back(tensor_pair.second->name);
group_func_args.emplace_back(tensor_pair.second->buffer,
ir::Argument::IO::kOutput);
}
}
auto func_body = ir_sch.GetModule().GetExprs().at(0);
auto func_body = ir_sch->GetModule().GetExprs().at(0);
#ifdef CINN_WITH_CUDA
optim::OptimizeExprGPU(&(func_body));
#endif
auto temp_buffers = lang::GetTempBuffers(arg_tensors, stages, func_body);
// 2.Prepare temp buffers
poly::StageMap stages;
auto temp_buffers =
lang::GetTempBuffers(*group_func_arg_tensors, stages, func_body);
// 3.Building LoweredFunc
auto func = ir::_LoweredFunc_::Make(group->GetFuncName(),
func_args,
ir_sch.GetModule().GetExprs().at(0),
group_func_args,
ir_sch->GetModule().GetExprs().at(0),
temp_buffers);
if (!done_op_schedule) {
func->PrepareBufferCastExprs();
}
// 4.Apply low level pass
func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref();
return {func};
}
std::vector<Expr> OpLowerer::IRElementwiseCompute(
poly::StageMap& stages,
std::vector<ir::Tensor>& func_tensors,
std::unordered_map<std::string, ir::Tensor>& tensor_map,
const GroupPtr& group,
const GroupPtr& sub_group,
bool apply_impl_schedule) {
VLOG(2) << "ElementwiseCompute Group : " << sub_group->group_id;
std::vector<ir::Expr> OpLowerer::LowerOps(
const std::vector<Node*>& nodes,
bool apply_op_schedule,
ScheduleDetermineFunction schedule_determine_func,
std::vector<ir::Tensor>* group_func_arg_tensors,
std::unordered_map<std::string, ir::Tensor>* tensor_map) {
auto& strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
std::vector<Expr> ast_exprs;
for (auto& node : sub_group->nodes) {
VLOG(4) << "Lower op: " << node->op()->name;
auto node_data = GetNodeData(node);
CHECK_EQ(GetAllNodeData(node).size(), 1U);
std::vector<common::CINNValue> cinn_inputs;
std::vector<ir::Tensor> tensor_inputs = std::move(CollectInputTensor(
node, func_tensors, tensor_map, this->type_dict_, this->shape_dict_));
for (auto& tensor : tensor_inputs) {
cinn_inputs.push_back(common::CINNValue(ir::Expr(tensor)));
}
// set tensor name = node data name
cinn_inputs.push_back(common::CINNValue(node_data->id()));
std::vector<Expr> func_bodies;
for (Node* node : nodes) {
// 1.Select Op impl
std::vector<Type> out_types;
std::vector<std::vector<int>> out_shapes;
std::vector<NodeData*> node_datas = GetAllNodeData(node);
for (const auto& node_data : node_datas) {
out_types.push_back(this->type_dict_.at(node_data->id()));
out_shapes.push_back(this->shape_dict_.at(node_data->id()));
auto impl = OpStrategy::SelectImpl(strategy[node->op()](
node->attrs, tensor_inputs, out_types, out_shapes, this->target_));
// do compute
common::CINNValuePack pack =
impl->fcompute(common::CINNValuePack{cinn_inputs});
CHECK_EQ(pack.size(), 2U);
Expr expr = pack[0];
poly::StageMap node_stages = pack.back();
tensor_inputs.push_back(expr.as_tensor_ref());
tensor_map[node_data->id()] = expr.as_tensor_ref();
auto func = lang::LowerVec("fn_" + node->id(),
node_stages,
tensor_inputs,
{},
{},
nullptr,
this->target_,
true);
CHECK_EQ(func.size(), 1);
if (apply_impl_schedule) {
std::vector<common::CINNValue> schedule_inputs;
// collect tensor
for (int idx = 0; idx < pack.size() - 1; ++idx) {
CHECK(pack[idx].is_tensor());
schedule_inputs.push_back(common::CINNValue(pack[idx]));
}
for (auto& f : func) {
schedule_inputs.push_back(common::CINNValue(f->body));
}
// do ast tree schedule
common::CINNValuePack expr_pack =
impl->fschedule(common::CINNValuePack{schedule_inputs});
CHECK_EQ(expr_pack.size(), 1);
Expr ast_expr = expr_pack[0];
ast_exprs.push_back(ast_expr);
std::vector<ir::Tensor> op_func_arg_tensors =
std::move(CollectInputTensor(node,
this->type_dict_,
this->shape_dict_,
group_func_arg_tensors,
tensor_map));
auto op_impl =
OpStrategy::SelectImpl(strategy[node->op()](node->attrs,
op_func_arg_tensors,
out_types,
out_shapes,
this->target_));
// 2.Perform the lower process of Op
std::vector<ir::LoweredFunc> funcs =
DoOpLower(op_impl, node, tensor_map, &op_func_arg_tensors);
if (apply_op_schedule && (this->*schedule_determine_func)(node)) {
// 3.Perform the schedule of Op
func_bodies.push_back(DoOpSchedule(op_impl, op_func_arg_tensors, funcs));
} else {
ast_exprs.push_back(func[0]->body);
for (const ir::LoweredFunc& func : funcs) {
func_bodies.push_back(func->body);
}
}
}
return ast_exprs;
return func_bodies;
}
std::vector<Expr> OpLowerer::IRReduceCompute(
poly::StageMap& stages,
std::vector<ir::Tensor>& func_args,
std::unordered_map<std::string, ir::Tensor>& tensor_map,
const GroupPtr& group,
const GroupPtr& sub_group,
bool apply_impl_schedule) {
VLOG(2) << "ReduceCompute Group : " << sub_group->group_id;
auto& cinn_strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
std::vector<Expr> ast_exprs;
for (auto& node : sub_group->nodes) {
auto node_data = GetNodeData(node);
VLOG(3) << "In ReduceCompute, process node: " << node->id()
<< " with op type: " << node->op()->name;
std::vector<ir::LoweredFunc> OpLowerer::DoOpLower(
std::shared_ptr<hlir::framework::OpImpl> op_impl,
Node* node,
std::unordered_map<std::string, ir::Tensor>* tensor_map,
std::vector<ir::Tensor>* op_func_arg_tensors) {
VLOG(4) << "Do lower with Compute, op: " << node->op()->name;
std::vector<common::CINNValue> cinn_inputs;
std::vector<ir::Tensor> tensor_inputs = std::move(CollectInputTensor(
node, func_args, tensor_map, this->type_dict_, this->shape_dict_));
for (auto& tensor : tensor_inputs) {
for (const ir::Tensor& tensor : *op_func_arg_tensors) {
cinn_inputs.push_back(common::CINNValue(ir::Expr(tensor)));
}
// set tensor name = node data name
std::vector<NodeData*> node_datas = GetAllNodeData(node);
for (const NodeData* node_data : node_datas) {
cinn_inputs.push_back(common::CINNValue(node_data->id()));
}
std::vector<Type> out_types;
std::vector<std::vector<int>> out_shapes;
out_types.push_back(this->type_dict_.at(node_data->id()));
out_shapes.push_back(this->shape_dict_.at(node_data->id()));
auto impl = OpStrategy::SelectImpl(cinn_strategy[node->op()](
node->attrs, tensor_inputs, out_types, out_shapes, target_));
// do compute
// 1.Do compute
common::CINNValuePack pack =
impl->fcompute(common::CINNValuePack{cinn_inputs});
op_impl->fcompute(common::CINNValuePack{cinn_inputs});
CHECK_GE(pack.size(), 2UL);
CHECK_LE(pack.size(), 5UL);
poly::StageMap tmp_stages = pack.back();
std::string post = "";
for (int idx = 0; idx < pack.size() - 1; ++idx) {
Expr expr = pack[idx];
tensor_map[node_data->id() + post] = expr.as_tensor_ref();
// As an op may have more than one output tensor, use id + "_0"/"_1" as the key.
// Insert the output tensor defined by Compute into the tensor_map
if (pack.size() - 1 > node_datas.size()) {
// Some nodes may output multiple temp tensors in their Compute
// definition but have only one output node_data in the graph, so we use
// id + "_0"/"_1" as the key.
(*tensor_map)[node_datas[0]->id() + post] = expr.as_tensor_ref();
post = "_" + std::to_string(idx);
} else {
// If the number of output tensors defined by Compute is less than or
// equal to the number of output node_data on the graph, they correspond
// one to one, and any redundant output node_data is left empty.
(*tensor_map)[node_datas[idx]->id()] = expr.as_tensor_ref();
}
// Insert output tensors
// Insert output tensors into function arg
if (!expr.as_tensor_ref()->buffer.defined() ||
this->target_ != common::DefaultNVGPUTarget()) {
tensor_inputs.push_back(expr.as_tensor_ref());
op_func_arg_tensors->push_back(expr.as_tensor_ref());
expr.as_tensor_ref()->WithBuffer();
}
}
auto func = lang::LowerVec("fn_" + node->id(),
// 2.Do lower
std::vector<ir::LoweredFunc> funcs = lang::LowerVec("fn_" + node->id(),
tmp_stages,
tensor_inputs,
*op_func_arg_tensors,
{},
{},
nullptr,
this->target_,
true);
VLOG(4) << "Lower op: " << node->op()->name << ", get " << funcs.size()
<< " LoweredFunc:\n";
// node is kReduction
if (op_pattern_dict[node->op()] == framework::kReduction &&
apply_impl_schedule) {
std::vector<common::CINNValue> schedule_inputs;
// collect tensor
op_func_arg_tensors->clear();
for (int idx = 0; idx < pack.size() - 1; ++idx) {
CHECK(pack[idx].is_tensor());
schedule_inputs.push_back(common::CINNValue(pack[idx]));
}
for (auto& f : func) {
schedule_inputs.push_back(common::CINNValue(f->body));
}
// do ast tree schedule
common::CINNValuePack expr_pack =
impl->fschedule(common::CINNValuePack{schedule_inputs});
// ast tree after schedule.
Expr ast_expr = expr_pack[0];
ast_exprs.push_back(ast_expr);
} else if (group->master_nodes.count(node)) {
// as the master node should copy its transform from the reducer, leave
// it to the reduce schedule.
ast_exprs.push_back(func[0]->body);
} else {
ast_exprs.push_back(func[0]->body);
}
op_func_arg_tensors->push_back(
pack[idx].operator ir::Expr().as_tensor_ref());
}
return ast_exprs;
return funcs;
}
std::vector<ir::LoweredFunc> OpLowerer::IRLowerNonFusibleOp(
GroupPtr& group, bool apply_impl_schedule) {
VLOG(3) << "LowerNonFusibleOp Group : " << group->group_id;
// get input tensor and output tensor
CHECK(group->nodes.size() || group->fused_sub_groups.size());
auto& cinn_strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
auto node = group->fused_sub_groups.size()
? group->fused_sub_groups[0]->nodes.front()
: group->nodes.front();
VLOG(3) << "GetOpFunc of op " << node->id();
std::vector<ir::Tensor> inputs;
std::vector<common::CINNValue> cinn_inputs;
std::vector<ir::Argument> args;
std::unordered_map<std::string, ir::Tensor> tensor_map;
for (auto& node_data : GetInputNodeData(node)) {
CHECK(node_data);
ir::Tensor tensor;
if (!tensor_map.count(node_data->id())) {
tensor = GetTensor(node_data, this->type_dict_, this->shape_dict_);
// record tensor.
tensor_map[node_data->id()] = tensor;
// input name.
group->input_names.push_back(node_data->id());
// input type.
args.emplace_back(tensor->buffer, ir::Argument::IO::kInput);
} else {
tensor = tensor_map[node_data->id()];
}
inputs.push_back(tensor);
cinn_inputs.push_back(common::CINNValue(tensor));
}
std::vector<Type> out_types;
std::vector<std::vector<int>> out_shapes;
auto node_datas = GetAllNodeData(node);
for (auto node_data : node_datas) {
VLOG(3) << "cinn_inputs.push_back " << node_data->id();
group->output_names.push_back(node_data->id());
out_types.push_back(this->type_dict_.at(node_data->id()));
out_shapes.push_back(this->shape_dict_.at(node_data->id()));
cinn_inputs.push_back(common::CINNValue(node_data->id()));
}
auto impl = OpStrategy::SelectImpl(cinn_strategy[node->op()](
node->attrs, inputs, out_types, out_shapes, target_));
// if node op is custom_call, apply custom_call compute.
if (node->op()->name == "custom_call") {
std::string external_api;
if (node->attrs.attr_store.count("custom_call")) {
external_api =
absl::get<std::string>(node->attrs.attr_store.at("custom_call"));
} else {
external_api =
ExternalApiRegistry::Global()->GetExternalApi(node, target_);
}
std::vector<common::CINNValue> compute_args = {
common::CINNValue(group->GetFuncName()),
common::CINNValue(external_api)};
common::CINNValuePack pack =
impl->fcompute(common::CINNValuePack{compute_args});
CHECK_EQ(pack.size(), 1UL);
// reset input names as extern api input args can't be deduplicated.
group->input_names.clear();
for (auto& inode : node->inlinks_in_order()) {
group->input_names.push_back(inode->source()->as<NodeData>()->id());
}
return {pack[0].operator ir::Expr().as_lowered_func_ref()};
}
common::CINNValuePack pack =
impl->fcompute(common::CINNValuePack{cinn_inputs});
for (int i = 0; i < pack->size() - 1; i++) {
ir::Expr temp = pack[i];
// check whether the tensor has a buffer.
if (!temp.as_tensor_ref()->buffer.defined() ||
this->target_ != common::DefaultNVGPUTarget()) {
inputs.push_back(temp.as_tensor_ref());
temp.as_tensor_ref()->WithBuffer();
args.emplace_back(temp.as_tensor_ref()->buffer,
ir::Argument::IO::kOutput);
}
}
poly::StageMap stages = pack.back();
auto func = lang::LowerVec(group->GetFuncName(),
stages,
inputs,
{},
{},
nullptr,
this->target_,
true);
if (apply_impl_schedule) {
ir::Expr OpLowerer::DoOpSchedule(
std::shared_ptr<hlir::framework::OpImpl> op_impl,
const std::vector<ir::Tensor>& op_func_arg_tensors,
const std::vector<ir::LoweredFunc>& lowered_funcs) {
VLOG(4) << "Do op schedule";
std::vector<common::CINNValue> schedule_inputs;
// collect tensor
for (int idx = 0; idx < pack.size() - 1; ++idx) {
CHECK(pack[idx].is_tensor());
schedule_inputs.push_back(common::CINNValue(pack[idx]));
// 1.Collect tensors
for (const ir::Tensor& op_func_arg_tensor : op_func_arg_tensors) {
schedule_inputs.push_back(common::CINNValue(op_func_arg_tensor));
}
for (auto& f : func) {
schedule_inputs.push_back(common::CINNValue(f->body));
// 2.Collect bodies to be scheduled
for (const ir::LoweredFunc& func : lowered_funcs) {
schedule_inputs.push_back(common::CINNValue(func->body));
}
// do ast tree schedule
// 3.Do schedule on AST
common::CINNValuePack expr_pack =
impl->fschedule(common::CINNValuePack{schedule_inputs});
ir::Expr func_body = expr_pack[0];
std::vector<std::string> input_output_nodes(group->input_names);
input_output_nodes.insert(input_output_nodes.end(),
group->output_names.begin(),
group->output_names.end());
VLOG(6) << "func.size() = " << func.size()
<< ", expr_pack.size() = " << expr_pack.size();
VLOG(6) << "args.size() = " << args.size()
<< ", input_output_nodes.size() = " << input_output_nodes.size();
if (args.size() > input_output_nodes.size()) {
args = lang::GetArgs(func_body, input_output_nodes);
}
std::vector<ir::LoweredFunc> res;
for (int i = 0; i < expr_pack.size(); i++) {
ir::Expr func_body = expr_pack[0];
#ifdef CINN_WITH_CUDA
optim::OptimizeExprGPU(&(func_body));
#endif
auto temp_buffers = lang::GetTempBuffers(inputs, stages, func_body);
auto function = ir::_LoweredFunc_::Make(
group->GetFuncName(), args, func_body, temp_buffers);
res.push_back(function);
}
for (auto& i : res) {
i = optim::Optimize(Expr(i), target_, false).as_lowered_func_ref();
}
return res;
} else {
for (auto& f : func) {
#ifdef CINN_WITH_CUDA
optim::OptimizeExprGPU(&(f->body));
#endif
f = optim::Optimize(Expr(f), target_, false).as_lowered_func_ref();
}
return func;
}
op_impl->fschedule(common::CINNValuePack{schedule_inputs});
VLOG(4) << "After op schedule: " << expr_pack[0].operator ir::Expr();
return expr_pack[0].operator ir::Expr();
}
// group schedule
void OpLowerer::IRSchedule(
ir::Expr OpLowerer::DoGroupSchedule(
ir::IRSchedule& ir_sch,
const GroupPtr& group,
const std::unordered_map<std::string, ir::Tensor>& tensor_map) {
@@ -698,6 +506,7 @@ void OpLowerer::IRSchedule(
<< ", ir is:\n"
<< ir_sch.GetModule().GetExprs().at(0);
// if node is horizontal with reduce or node is reduce, loop assign
//
// master.
auto loops = ir_sch.GetLoops(GetNodeData(node)->id());
if (op_pattern_dict[node->op()] == framework::kElementWise) {
@@ -788,6 +597,7 @@ void OpLowerer::IRSchedule(
ir_sch, group, nodes_inline, nodes_set, this->shape_dict_, tensor_map);
VLOG(4) << "After IRSchedule, ir is: \n"
<< ir_sch.GetModule().GetExprs().at(0);
return ir_sch.GetModule().GetExprs().at(0);
}
} // namespace framework
......
@@ -39,46 +39,132 @@ using GroupPtr = std::shared_ptr<Graph::Group>;
using common::Target;
class OpLowerer;
typedef std::vector<Expr> (OpLowerer::*IRComputeFunction)(
poly::StageMap&,
std::vector<ir::Tensor>&,
std::unordered_map<std::string, ir::Tensor>&,
const GroupPtr&,
const GroupPtr&,
bool);
typedef bool (OpLowerer::*ScheduleDetermineFunction)(Node*);
class OpLowerer {
public:
OpLowerer(const absl::flat_hash_map<std::string, Type>&,
const absl::flat_hash_map<std::string, shape_t>&,
const Target&);
std::vector<ir::LoweredFunc> Lower(GroupPtr& group); // NOLINT
std::vector<ir::LoweredFunc> LowerWithoutSchedule(GroupPtr& group); // NOLINT
/**
* @brief Lower a group to CINN IR.
* @param group The group to be lowered.
* @param apply_op_schedule Whether to schedule at Op level.
* @param apply_group_schedule Whether to schedule at group level.
* @return The lowered funcs.
*/
std::vector<ir::LoweredFunc> Lower(const GroupPtr& group,
bool apply_op_schedule = true,
bool apply_group_schedule = true);
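Because both flags default to true, Lower(group) keeps the fully scheduled behavior of the old Lower, while the old LowerWithoutSchedule becomes Lower(group, false, false); mixed settings are now also expressible. A sketch of the equivalences under the new defaults:

  //   old: op_lowerer.Lower(g);                 new: op_lowerer.Lower(g);
  //   old: op_lowerer.LowerWithoutSchedule(g);  new: op_lowerer.Lower(g, false, false);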
private:
std::vector<ir::LoweredFunc> IRLowerOp(IRComputeFunction, GroupPtr&);
std::vector<ir::LoweredFunc> IRLowerNonFusibleOp(GroupPtr&, bool);
std::vector<ir::LoweredFunc> IRLowerOpWithoutSchedule(IRComputeFunction,
GroupPtr&);
#define DEFINE_IR_COMPUTE(type) \
std::vector<Expr> IR##type##Compute( \
poly::StageMap& stages, \
std::vector<ir::Tensor>& func_args, \
std::unordered_map<std::string, ir::Tensor>& tensor_map, \
const GroupPtr& group, \
const GroupPtr& sub_group, \
bool apply_impl_schedule = false);
// compute and schedule
DEFINE_IR_COMPUTE(Elementwise);
DEFINE_IR_COMPUTE(Reduce);
DEFINE_IR_COMPUTE(OutEWiseFusable);
void IRSchedule(
/**
* @brief Lower a group to CINN IR.
* @param group The group to be lowered.
* @param apply_op_schedule Whether to schedule at Op level.
* @param apply_group_schedule Whether to schedule at group level.
* @param schedule_determine_func Function used to determine which Ops to
* schedule.
* @return The lowered funcs.
*/
std::vector<ir::LoweredFunc> LowerGroup(
const GroupPtr& group,
bool apply_op_schedule,
bool apply_group_schedule,
ScheduleDetermineFunction schedule_determine_func);
/**
* @brief Lower a group composed of CustomCall Op.
* @param group The group to be lowered.
* @return The lowered funcs.
*/
std::vector<ir::LoweredFunc> LowerCustomCall(const GroupPtr& group);
/**
* @brief Post processing, including preparing function args and temporary
* variables, applying low-level optimization passes, etc.
* @param group The group to be lowered.
* @param tensor_map All tensors used for calculating the group.
* @param done_op_schedule Mark whether the Op level schedule has been
* applied.
* @param ir_sch The IRSchedule object of group.
* @param group_func_arg_tensors Tensors used as the group function arguments.
* @return The lowered funcs after the post processing.
*/
std::vector<ir::LoweredFunc> PostProcess(
const GroupPtr& group,
const std::unordered_map<std::string, ir::Tensor>& tensor_map,
bool done_op_schedule,
ir::IRSchedule* ir_sch,
std::vector<ir::Tensor>* group_func_arg_tensors);
/**
* @brief Lower an Op set to CINN IR.
* Compute, Lower and optional Schedule will be performed one by one
* for each Op.
* @param nodes The Op nodes to be lowered.
* @param apply_op_schedule Whether to schedule at Op level.
* @param schedule_determine_func Function used to determine which Ops to
* schedule.
* @param group_func_arg_tensors Tensors used as the group function arguments.
* @param tensor_map All tensors used for calculating the group.
* @return The lowered func bodies of Op set.
*/
std::vector<ir::Expr> LowerOps(
const std::vector<Node*>& nodes,
bool apply_op_schedule,
ScheduleDetermineFunction schedule_determine_func,
std::vector<ir::Tensor>* group_func_arg_tensors,
std::unordered_map<std::string, ir::Tensor>* tensor_map);
/**
* @brief Lower an Op to CINN IR. The Compute and Lower processes will be
* called sequentially.
* @param op_impl The Op implementation defining Compute and Schedule.
* @param node The Op node to be lowered.
* @param tensor_map All tensors used for calculating the group.
* @param op_func_arg_tensors Tensors used as the Op function arguments.
* @return The lowered func of the Op node.
*/
std::vector<ir::LoweredFunc> DoOpLower(
std::shared_ptr<hlir::framework::OpImpl> op_impl,
Node* node,
std::unordered_map<std::string, ir::Tensor>* tensor_map,
std::vector<ir::Tensor>* op_func_arg_tensors);
/**
* @brief Apply schedule on an Op.
* @param op_impl The Op implementation defining Compute and Schedule.
* @param op_func_arg_tensors Tensors used as the Op function arguments.
* @param lowered_funcs The lowered funcs of an Op to be scheduled.
* @return The lowered func body after schedule of the Op.
*/
ir::Expr DoOpSchedule(std::shared_ptr<hlir::framework::OpImpl> op_impl,
const std::vector<ir::Tensor>& op_func_arg_tensors,
const std::vector<ir::LoweredFunc>& lowered_funcs);
/**
* @brief Apply schedule on a group.
* @param ir_sch The IRSchedule containing the entire group's lowered func
* bodies.
* @param group The group to be scheduled.
* @param tensor_map All tensors used for calculating the group.
* @return The lowered func body after schedule of the group.
*/
ir::Expr DoGroupSchedule(
ir::IRSchedule& ir_sch, // NOLINT
const GroupPtr& group,
const std::unordered_map<std::string, ir::Tensor>& tensor_map);
// Functions used to determine which Ops to schedule at op level; each
// defines the policy for one type of group.
inline bool ReduceScheduleDetermineFunction(Node* node);
inline bool ElementwiseScheduleDetermineFunction(Node* node);
inline bool NonFusibleScheduleDetermineFunction(Node* node);
private:
Target target_;
const absl::flat_hash_map<std::string, Type>& type_dict_;
const absl::flat_hash_map<std::string, shape_t>& shape_dict_;
......
@@ -92,19 +92,19 @@ ir::Tensor GetTensor(
std::vector<ir::Tensor> CollectInputTensor(
const Node* node,
std::vector<ir::Tensor>& func_args, // NOLINT
std::unordered_map<std::string, ir::Tensor>& tensor_map, // NOLINT
const absl::flat_hash_map<std::string, Type>& type_dict,
const absl::flat_hash_map<std::string, shape_t>& shape_dict) {
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
std::vector<ir::Tensor>* func_args,
std::unordered_map<std::string, ir::Tensor>* tensor_map) {
std::vector<ir::Tensor> tensors;
// get all input nodes
for (auto& node_data : GetInputNodeData(node)) {
CHECK(node_data);
auto tensor = GetTensor(node_data, type_dict, shape_dict);
if (!tensor_map.count(node_data->id())) {
tensor_map[node_data->id()] = tensor;
if (!tensor_map->count(node_data->id())) {
(*tensor_map)[node_data->id()] = tensor;
// record func input args
func_args.push_back(tensor);
func_args->push_back(tensor);
}
tensors.push_back(tensor);
}
......
@@ -31,10 +31,10 @@ ir::Tensor GetTensor(
std::vector<ir::Tensor> CollectInputTensor(
const Node* node,
std::vector<ir::Tensor>& func_args, // NOLINT
std::unordered_map<std::string, ir::Tensor>& tensor_map, // NOLINT
const absl::flat_hash_map<std::string, Type>& type_dict,
const absl::flat_hash_map<std::string, shape_t>& shape_dict);
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
std::vector<ir::Tensor>* func_args,
std::unordered_map<std::string, ir::Tensor>* tensor_map);
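With the output parameters now passed by pointer rather than by // NOLINT reference, mutation is visible at the call site. A sketch of the new call shape, matching the LowerOps call in the .cc diff (local names illustrative):

  std::vector<ir::Tensor> group_func_arg_tensors;
  std::unordered_map<std::string, ir::Tensor> tensor_map;
  std::vector<ir::Tensor> op_inputs = CollectInputTensor(
      node, type_dict, shape_dict, &group_func_arg_tensors, &tensor_map);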
std::unordered_map<Node*, Node*> BuildVirtualConsumer(
const GroupPtr& group,
......