Commit aa70ff8e authored by dingminghui, committed by jackzhang235

feat(mir): support mlu cast and transpose in MIR

Parent 90ed96c3
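In brief (as read from the diff below): the pass gains a use_mlu_cast switch, taken from the LITE_MLU_CAST environment variable. When it is off, the pass keeps inserting CPU-side precision and layout cast nodes around each MLU subgraph as before; when it is on, those casts are instead created as ops inside the subgraph block itself (see AdjustSubgraph below), so the conversions run on the MLU.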
@@ -200,7 +200,8 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
 void MLUPostprocessPass::InsertBefore(SSAGraph* graph,
                                       Node* head_node,
                                       Node* inst_node,
-                                      const Type* inst_type) {
+                                      const Type* inst_type,
+                                      bool use_mlu_cast) {
   const auto* head_type = head_node->AsArg().type;
   // break original link
@@ -215,27 +216,30 @@ void MLUPostprocessPass::InsertBefore(SSAGraph* graph,
                 head_node->AsArg().name) != first_conv_nodes_.end();
   // precision cast node
-  if (head_type->precision() != inst_type->precision() && !is_first_conv_head) {
-    cur_node = InsertCastBefore(
-        "cast",
-        name_prefix + "cast",
-        graph,
-        cur_node,
-        inst_node,
-        LiteType::GetTensorTy(
-            head_type->target(), inst_type->precision(), head_type->layout()));
-  }
-
-  // layout cast node
-  if (head_type->layout() != inst_type->layout()) {
-    cur_node = InsertCastBefore(
-        "layout",
-        name_prefix + "layout",
-        graph,
-        cur_node,
-        inst_node,
-        LiteType::GetTensorTy(
-            head_type->target(), inst_type->precision(), inst_type->layout()));
-  }
+  if (!use_mlu_cast) {
+    if (head_type->precision() != inst_type->precision() &&
+        !is_first_conv_head) {
+      cur_node = InsertCastBefore("cast",
+                                  name_prefix + "cast",
+                                  graph,
+                                  cur_node,
+                                  inst_node,
+                                  LiteType::GetTensorTy(head_type->target(),
+                                                        inst_type->precision(),
+                                                        head_type->layout()));
+    }
+
+    // layout cast node
+    if (head_type->layout() != inst_type->layout()) {
+      cur_node = InsertCastBefore("layout",
+                                  name_prefix + "layout",
+                                  graph,
+                                  cur_node,
+                                  inst_node,
+                                  LiteType::GetTensorTy(head_type->target(),
+                                                        inst_type->precision(),
+                                                        inst_type->layout()));
+    }
+  }
 
   // io copy
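When use_mlu_cast is set, the CPU-side precision and layout casts are skipped here and in InsertAfter below; only the io_copy insertion still happens on the host path, and AdjustSubgraph (added further down) creates the equivalent cast/layout ops inside the subgraph block instead.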
@@ -353,7 +357,8 @@ bool MLUPostprocessPass::NeedInsert(Node* node, const Type* inst_type) {
 void MLUPostprocessPass::InsertAfter(SSAGraph* graph,
                                      Node* tail_node,
                                      Node* inst_node,
-                                     const Type* inst_type) {
+                                     const Type* inst_type,
+                                     bool use_mlu_cast) {
   const auto* tail_type = tail_node->AsArg().type;
   // break original link
@@ -364,27 +369,29 @@ void MLUPostprocessPass::InsertAfter(SSAGraph* graph,
       tail_node->AsArg().name + string_format("_%p", inst_node) + "/trans_";
   // precision cast node
-  if (tail_type->precision() != inst_type->precision()) {
-    cur_node = InsertCastAfter(
-        "cast",
-        name_prefix + "cast",
-        graph,
-        cur_node,
-        inst_node,
-        LiteType::GetTensorTy(
-            tail_type->target(), inst_type->precision(), tail_type->layout()));
-  }
-
-  // layout cast node
-  if (tail_type->layout() != inst_type->layout()) {
-    cur_node = InsertCastAfter(
-        "layout",
-        name_prefix + "layout",
-        graph,
-        cur_node,
-        inst_node,
-        LiteType::GetTensorTy(
-            tail_type->target(), inst_type->precision(), inst_type->layout()));
-  }
+  if (!use_mlu_cast) {
+    if (tail_type->precision() != inst_type->precision()) {
+      cur_node = InsertCastAfter("cast",
+                                 name_prefix + "cast",
+                                 graph,
+                                 cur_node,
+                                 inst_node,
+                                 LiteType::GetTensorTy(tail_type->target(),
+                                                       inst_type->precision(),
+                                                       tail_type->layout()));
+    }
+
+    // layout cast node
+    if (tail_type->layout() != inst_type->layout()) {
+      cur_node = InsertCastAfter("layout",
+                                 name_prefix + "layout",
+                                 graph,
+                                 cur_node,
+                                 inst_node,
+                                 LiteType::GetTensorTy(tail_type->target(),
+                                                       inst_type->precision(),
+                                                       inst_type->layout()));
+    }
+  }
 
   // io copy
@@ -447,7 +454,7 @@ bool MLUPostprocessPass::IsFirstConvInSubgraph(Node* arg_node, Node* inst) {
   auto* block_desc =
       static_cast<operators::SubgraphOp*>(inst->AsStmt().op().get())
           ->GetSubBlock();
-  for (int op_idx = 0; op_idx < block_desc->OpsSize(); op_idx++) {
+  for (size_t op_idx = 0; op_idx < block_desc->OpsSize(); op_idx++) {
     auto op_desc = block_desc->GetOp<cpp::OpDesc>(op_idx);
     CHECK(op_desc);
     if (op_desc->Type() == "conv2d") {
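The loop index switches from int to size_t, presumably so it matches the unsigned type returned by BlockDesc::OpsSize() and avoids a signed/unsigned comparison warning.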
@@ -602,6 +609,149 @@ void MLUPostprocessPass::ModifyLayout(SSAGraph* graph) {
   }
 }
 
+std::string CheckInputAndInsert(cpp::BlockDesc* block_desc,
+                                const std::string& input_name,
+                                const Type* tensor_type,
+                                const Type* subgraph_type) {
+  auto cur_node = input_name;
+  if (DataLayoutCompatible(*tensor_type, *subgraph_type)) {
+    auto layout_op = block_desc->AddOp<cpp::OpDesc>();
+    auto layout_arg_name = string_format("%s/layout", cur_node.c_str());
+    layout_op->SetType("layout");
+    layout_op->SetInput("X", {cur_node});
+    layout_op->SetOutput("Out", {layout_arg_name});
+    cur_node = layout_arg_name;
+  }
+
+  if (PrecisionCompatible(*tensor_type, *subgraph_type) &&
+      tensor_type->precision() != PRECISION(kInt8)) {
+    auto cast_op = block_desc->AddOp<cpp::OpDesc>();
+    auto cast_arg_name = string_format("%s/cast", cur_node.c_str());
+    cast_op->SetType("cast");
+    cast_op->SetAttr<int>("in_dtype", 5);   // FP32
+    cast_op->SetAttr<int>("out_dtype", 4);  // FP16
+    cast_op->SetInput("X", {cur_node});
+    cast_op->SetOutput("Out", {cast_arg_name});
+    cur_node = cast_arg_name;
+  }
+
+  return cur_node;
+}
+
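A note on the helper above: each inserted op consumes the current name and produces a suffixed one, so the returned name encodes the transform chain and is what the subgraph's original ops are later rewired to read. The dtype attribute values follow the fluid VarType convention (FP16 = 4, FP32 = 5). A minimal standalone sketch of the name chaining, with hypothetical names not taken from the patch:

#include <cstdio>
#include <string>

// Hypothetical sketch of the input-side name chaining, assuming both a
// layout and a precision transform are needed for a tensor "x":
//   x --[layout]--> x/layout --[cast FP32->FP16]--> x/layout/cast
std::string ChainInputName(std::string cur, bool needs_layout, bool needs_cast) {
  if (needs_layout) cur += "/layout";  // layout op: X={cur}, Out={cur + "/layout"}
  if (needs_cast) cur += "/cast";      // cast op:   X={cur}, Out={cur + "/cast"}
  return cur;
}

int main() {
  // The subgraph's ops are renamed to consume "x/layout/cast" instead of "x".
  std::printf("%s\n", ChainInputName("x", true, true).c_str());
}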
+std::string CheckOutputAndInsert(cpp::BlockDesc* block_desc,
+                                 const std::string& output_name,
+                                 const Type* tensor_type,
+                                 const Type* subgraph_type) {
+  auto cur_node = output_name;
+  cpp::OpDesc *layout_op = nullptr, *cast_op = nullptr;
+  // subgraph -> cast -> layout -> output
+  if (PrecisionCompatible(*tensor_type, *subgraph_type)) {
+    cast_op = block_desc->AddOp<cpp::OpDesc>();
+    cast_op->SetType("cast");
+    cast_op->SetAttr<int>("in_dtype", 4);   // FP16
+    cast_op->SetAttr<int>("out_dtype", 5);  // FP32
+  }
+
+  if (DataLayoutCompatible(*tensor_type, *subgraph_type)) {
+    auto layout_arg_name = string_format("%s/layout", cur_node.c_str());
+    layout_op = block_desc->AddOp<cpp::OpDesc>();
+    layout_op->SetType("layout");
+    layout_op->SetInput("X", {layout_arg_name});
+    layout_op->SetOutput("Out", {cur_node});
+    cur_node = layout_arg_name;
+  }
+
+  if (cast_op) {
+    auto cast_arg_name = string_format("%s/cast", cur_node.c_str());
+    cast_op->SetInput("X", {cast_arg_name});
+    cast_op->SetOutput("Out", {cur_node});
+    cur_node = cast_arg_name;
+  }
+
+  return cur_node;
+}
+
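The output side is wired backwards: the ops are created first and then connected inward from the original output name, so the executed chain is subgraph producer -> cast -> layout -> original output. A hypothetical trace for an output "y" that needs both transforms (names invented for illustration):

#include <cstdio>
#include <string>

int main() {
  // Working backwards from the original output name "y":
  std::string out = "y";
  std::string layout_in = out + "/layout";    // layout op: X=y/layout, Out=y
  std::string cast_in = layout_in + "/cast";  // cast op:   X=y/layout/cast, Out=y/layout
  // CheckOutputAndInsert returns cast_in; AdjustSubgraph then rewires the
  // subgraph's internal producer to write it.
  std::printf("producer -> %s -> [cast] -> %s -> [layout] -> %s\n",
              cast_in.c_str(), layout_in.c_str(), out.c_str());
}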
+// Insert cast ops inside the MLU subgraph so the casts run on the MLU
+// instead of the CPU; must be invoked before the first run.
+void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
+                                        const Type* op_type) {
+  auto subgraph_op = subgraph_node->AsStmt().op();
+  CHECK(subgraph_op->Type() == "subgraph");
+  auto op = dynamic_cast<operators::SubgraphOp*>(subgraph_op.get());
+  CHECK(op);
+  auto block_desc = op->GetSubBlock();
+
+  // create a new block desc to keep the op sequence correct
+  cpp::BlockDesc* new_block_desc = new cpp::BlockDesc();
+  new_block_desc->ClearOps();
+  new_block_desc->ClearVars();
+  new_block_desc->SetIdx(block_desc->Idx());
+  new_block_desc->SetParentIdx(block_desc->ParentIdx());
+  new_block_desc->SetForwardBlockIdx(block_desc->ForwardBlockIdx());
+
+  // collect all IO tensors that are not weights or persistent
+  std::list<std::string> i_names, o_names;
+  std::map<std::string, std::string> node_replace;
+
+  // insert cast/layout ops for each such IO tensor; input-side ops go into
+  // the new block so they precede the original ops, while output-side ops
+  // are appended to the old block and thus get copied after the originals
+  for (auto& input : subgraph_node->inlinks) {
+    auto input_name = input->AsArg().name;
+    if (!(input->AsArg().is_weight || input->AsArg().is_persist)) {
+      i_names.emplace_back(input_name);
+      node_replace[input_name] = CheckInputAndInsert(
+          new_block_desc, input_name, input->AsArg().type, op_type);
+    }
+  }
+  for (auto& output : subgraph_node->outlinks) {
+    auto output_name = output->AsArg().name;
+    if (!(output->AsArg().is_weight || output->AsArg().is_persist)) {
+      o_names.emplace_back(output_name);
+      node_replace[output_name] = CheckOutputAndInsert(
+          block_desc, output_name, output->AsArg().type, op_type);
+    }
+  }
+
+  // copy ops into the new block, rewiring renamed inputs and outputs
+  for (size_t op_idx = 0; op_idx < block_desc->OpsSize(); ++op_idx) {
+    auto desc = block_desc->GetOp<cpp::OpDesc>(op_idx);
+    auto new_desc = new_block_desc->AddOp<cpp::OpDesc>();
+    *new_desc = *desc;
+    auto op_input_args = new_desc->InputArgumentNames();
+    for (auto& input_arg : op_input_args) {
+      auto op_input = new_desc->Input(input_arg);
+      for (auto& it : i_names) {
+        auto index = std::find(op_input.cbegin(), op_input.cend(), it);
+        if (index != op_input.cend()) {
+          index = op_input.erase(index);
+          op_input.emplace(index, node_replace.at(it));
+          VLOG(4) << "[" << new_desc->Type() << "] change input from " << it
+                  << " to " << node_replace.at(it);
+        }
+      }
+      new_desc->SetInput(input_arg, op_input);
+    }
+    auto op_output_args = new_desc->OutputArgumentNames();
+    for (auto& output_arg : op_output_args) {
+      auto op_output = new_desc->Output(output_arg);
+      for (auto& it : o_names) {
+        auto index = std::find(op_output.cbegin(), op_output.cend(), it);
+        if (index != op_output.cend()) {
+          index = op_output.erase(index);
+          op_output.emplace(index, node_replace.at(it));
+          VLOG(4) << "[" << new_desc->Type() << "] change output from " << it
+                  << " to " << node_replace.at(it);
+        }
+      }
+      new_desc->SetOutput(output_arg, op_output);
+    }
+  }
+
+  op->SetSubBlock(new_block_desc);
+  // re-attach the kernel so it picks up the updated sub-block param
+  op->AttachKernel(op->GetKernel());
+}
+
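The rewiring loop above uses an erase-then-emplace idiom to substitute a renamed argument while keeping its position in the op's argument list. A minimal sketch of that idiom, with a hypothetical helper name:

#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Replace the first occurrence of `from` with `to`, preserving its position:
// vector::erase returns the iterator at the removed slot, which emplace reuses.
void ReplaceArg(std::vector<std::string>* args,
                const std::string& from,
                const std::string& to) {
  auto it = std::find(args->begin(), args->end(), from);
  if (it != args->end()) {
    it = args->erase(it);
    args->emplace(it, to);
  }
}

int main() {
  std::vector<std::string> inputs = {"w", "x", "b"};
  ReplaceArg(&inputs, "x", "x/layout/cast");
  assert(inputs[1] == "x/layout/cast");  // position preserved
}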
 void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   // currently for non-persistent input and output args, mlu subgraph op
   // only support float16/float32 data type
@@ -623,22 +773,28 @@ void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 #endif
   g_stream_id = static_cast<int>(reinterpret_cast<int64_t>(graph.get()));
+  bool use_mlu_cast = GetBoolFromEnv("LITE_MLU_CAST");
   // insert io_copy, layout and precision cast of subgraph's inputs and outputs
   for (auto& node : graph->mutable_nodes()) {
     if (node.IsStmt() && node.AsStmt().op_type() == "subgraph") {
       const Type* subgraph_arg_type = nullptr;
       GetSubgraphOpArgType(&node, &subgraph_arg_type, graph.get());
+      if (use_mlu_cast) {
+        AdjustSubgraph(&node, subgraph_arg_type);
+      }
       auto links_tmp = node.inlinks;
       for (auto p_in : links_tmp) {
         if (NeedInsert(p_in, subgraph_arg_type)) {
-          InsertBefore(graph.get(), p_in, &node, subgraph_arg_type);
+          InsertBefore(
+              graph.get(), p_in, &node, subgraph_arg_type, use_mlu_cast);
         }
       }
       links_tmp.assign(node.outlinks.begin(), node.outlinks.end());
       for (auto p_out : links_tmp) {
         if (NeedInsert(p_out, subgraph_arg_type)) {
-          InsertAfter(graph.get(), p_out, &node, subgraph_arg_type);
+          InsertAfter(
+              graph.get(), p_out, &node, subgraph_arg_type, use_mlu_cast);
         }
       }
     }
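Since the flag is read via GetBoolFromEnv("LITE_MLU_CAST"), the new path is presumably enabled by exporting LITE_MLU_CAST with a truthy value before running; with the variable unset, InsertBefore and InsertAfter behave exactly as before the patch.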
...
@@ -88,12 +88,14 @@ class MLUPostprocessPass : public ProgramPass {
   void InsertBefore(SSAGraph* graph,
                     Node* head_node,
                     Node* inst_node,
-                    const Type* type);
+                    const Type* type,
+                    bool use_mlu_cast);
 
   void InsertAfter(SSAGraph* graph,
                    Node* tail_node,
                    Node* inst_node,
-                   const Type* type);
+                   const Type* type,
+                   bool use_mlu_cast);
 
   Node* InsertCastBefore(const std::string& op_type,
                          const std::string& cast_arg_name,
@@ -117,6 +119,8 @@ class MLUPostprocessPass : public ProgramPass {
   bool IsFirstConvInSubgraph(Node* arg_node, Node* inst);
 
+  void AdjustSubgraph(Node* subgraph_node, const Type* op_type);
+
  private:
   std::set<std::string> first_conv_nodes_;
 };