Commit 5f12addc authored by dingminghui, committed by jackzhang235

fix(mir): fix erroneous mlu cast and layout insertion

Parent 56c1f666
@@ -134,8 +134,8 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
cpp::OpDesc op_desc;
op_desc.SetType(op_type);
if (op_type == "cast") {
op_desc.SetAttr<int>("in_dtype", 4); // FP32
op_desc.SetAttr<int>("out_dtype", 5); // FP16
op_desc.SetAttr<int>("in_dtype", 4); // FP16
op_desc.SetAttr<int>("out_dtype", 5); // FP32
op_desc.SetInput("X", {cast_arg_name});
op_desc.SetOutput("Out", {cur_node->AsArg().name});
} else if (op_type == "layout") {
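Note on the dtype attributes above: the integer values follow Paddle's framework::proto::VarType encoding, in which 4 denotes FP16 and 5 denotes FP32, so the corrected comments now match the values, and this cast turns the subgraph's FP16 result back into FP32. A minimal sketch of that encoding (the enum and helper below are illustrative only, not part of this patch):

    // Subset of the VarType dtype codes used by the cast op attributes.
    enum class VarDType : int { BOOL = 0, INT16 = 1, INT32 = 2, INT64 = 3, FP16 = 4, FP32 = 5, FP64 = 6 };

    // Hypothetical helper: configure a cast op that widens FP16 back to FP32.
    inline void ConfigCastFp16ToFp32(cpp::OpDesc* op_desc) {
      op_desc->SetAttr<int>("in_dtype", static_cast<int>(VarDType::FP16));   // 4
      op_desc->SetAttr<int>("out_dtype", static_cast<int>(VarDType::FP32));  // 5
    }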
@@ -241,17 +241,27 @@ void MLUPostprocessPass::InsertBefore(SSAGraph* graph,
inst_type->precision(),
inst_type->layout()));
}
}
// io copy
cur_node = InsertCastBefore(
"io_copy",
name_prefix + "io_copy",
graph,
cur_node,
inst_node,
LiteType::GetTensorTy(
inst_type->target(), inst_type->precision(), inst_type->layout()));
// io copy
cur_node = InsertCastBefore(
"io_copy",
name_prefix + "io_copy",
graph,
cur_node,
inst_node,
LiteType::GetTensorTy(
inst_type->target(), inst_type->precision(), inst_type->layout()));
} else {
// io copy
cur_node = InsertCastBefore(
"io_copy",
name_prefix + "io_copy",
graph,
cur_node,
inst_node,
LiteType::GetTensorTy(
inst_type->target(), head_type->precision(), head_type->layout()));
}
// connect cur_node to inst_node
DirectedLink(cur_node, inst_node);
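In both InsertBefore (above) and InsertAfter (further down), the io_copy insertion is now split into two branches: if a cast or layout node was inserted on this path, the io_copy tensor type follows inst_type's precision and layout; otherwise it keeps the head (or tail) type's precision and layout, so no implicit conversion is assumed. A condensed sketch of the input-side branching, where inserted_cast_or_layout stands in for the surrounding condition:

    const Type* io_copy_type = inserted_cast_or_layout
        ? LiteType::GetTensorTy(inst_type->target(), inst_type->precision(), inst_type->layout())
        : LiteType::GetTensorTy(inst_type->target(), head_type->precision(), head_type->layout());
    cur_node = InsertCastBefore(
        "io_copy", name_prefix + "io_copy", graph, cur_node, inst_node, io_copy_type);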
@@ -393,17 +403,26 @@ void MLUPostprocessPass::InsertAfter(SSAGraph* graph,
inst_type->precision(),
inst_type->layout()));
}
}
// io copy
cur_node = InsertCastAfter(
"io_copy",
name_prefix + "io_copy",
graph,
cur_node,
inst_node,
LiteType::GetTensorTy(
inst_type->target(), inst_type->precision(), inst_type->layout()));
// io copy
cur_node = InsertCastAfter(
"io_copy",
name_prefix + "io_copy",
graph,
cur_node,
inst_node,
LiteType::GetTensorTy(
inst_type->target(), inst_type->precision(), inst_type->layout()));
} else {
cur_node = InsertCastAfter(
"io_copy",
name_prefix + "io_copy",
graph,
cur_node,
inst_node,
LiteType::GetTensorTy(
inst_type->target(), tail_type->precision(), tail_type->layout()));
}
// connect cur_node to inst_node
DirectedLink(inst_node, cur_node);
@@ -504,6 +523,8 @@ void MLUPostprocessPass::GatherAndModifyFirstConvNodes(SSAGraph* graph) {
void MLUPostprocessPass::ModifyInputOutputDataType(SSAGraph* graph) {
for (auto& node : graph->mutable_nodes()) {
if (node.IsStmt() && node.AsStmt().op_type() == "subgraph") {
const Type* subgraph_arg_type = nullptr;
GetSubgraphOpArgType(&node, &subgraph_arg_type, graph);
for (auto& in_node : node.inlinks) {
const auto* in_node_type = in_node->AsArg().type;
VLOG(4) << "MLU subgraph input type: " << in_node->AsArg().name
@@ -525,6 +546,7 @@ void MLUPostprocessPass::ModifyInputOutputDataType(SSAGraph* graph) {
}
for (auto& out_node : node.outlinks) {
const auto* out_node_type = out_node->AsArg().type;
auto& out_arg = out_node->AsArg();
VLOG(4) << "MLU subgraph output type: " << out_node->AsArg().name
<< *out_node_type;
if (out_node->AsArg().is_weight || out_node->AsArg().is_persist) {
@@ -536,14 +558,24 @@ void MLUPostprocessPass::ModifyInputOutputDataType(SSAGraph* graph) {
TARGET(kMLU), PRECISION(kAny), DATALAYOUT(kNHWC));
} else if (out_node_type->precision() == PRECISION(kAny) &&
out_node->outlinks.empty()) {
out_node->AsArg().is_persist = true;
out_node->AsArg().type = LiteType::GetTensorTy(
out_arg.is_persist = true;
out_arg.type = LiteType::GetTensorTy(
TARGET(kMLU), PRECISION(kAny), DATALAYOUT(kNHWC));
} else {
CHECK(out_node_type->precision() == PRECISION(kFloat))
<< "MLU subgraph unexpected common output type!";
out_node->AsArg().type = LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
if (out_node->outlinks.empty()) {
out_arg.type = LiteType::GetTensorTy(TARGET(kHost),
subgraph_arg_type->precision(),
DATALAYOUT(kNHWC));
VLOG(5) << "unused output node type: " << out_arg.name
<< out_node_type->name();
} else {
out_arg.type = LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
VLOG(5) << "output node type: " << out_arg.name
<< out_node_type->name();
}
}
const auto target = out_node->AsArg().type->target();
const auto precision = out_node->AsArg().type->precision();
@@ -610,79 +642,88 @@ void MLUPostprocessPass::ModifyLayout(SSAGraph* graph) {
}
}
std::string CheckInputAndInsert(Scope* scope,
cpp::BlockDesc* block_desc,
const std::string& input_name,
const Type* tensor_type,
const Type* subgraph_type) {
std::pair<bool, std::string> CheckInputAndInsert(Scope* scope,
cpp::BlockDesc* block_desc,
const std::string& input_name,
const Type* tensor_type,
const Type* subgraph_type) {
auto cur_node = input_name;
if (DataLayoutCompatible(*tensor_type, *subgraph_type)) {
bool do_insert = false;
if (!DataLayoutCompatible(*tensor_type, *subgraph_type)) {
auto layout_op = block_desc->AddOp<cpp::OpDesc>();
auto layout_arg_name = string_format("%s/layout", cur_node.c_str());
scope->Var(layout_arg_name);
VLOG(5) << "insert layout in subgraph, arg tensor name: "
VLOG(5) << "insert layout for subgraph input, arg tensor name: "
<< layout_arg_name;
layout_op->SetType("layout");
layout_op->SetInput("Input", {cur_node});
layout_op->SetOutput("Out", {layout_arg_name});
cur_node = layout_arg_name;
do_insert = true;
}
if (PrecisionCompatible(*tensor_type, *subgraph_type) &&
if (!PrecisionCompatible(*tensor_type, *subgraph_type) &&
tensor_type->precision() != PRECISION(kInt8)) {
auto cast_op = block_desc->AddOp<cpp::OpDesc>();
auto cast_arg_name = string_format("%s/cast", cur_node.c_str());
scope->Var(cast_arg_name);
VLOG(5) << "insert cast in subgraph, arg tensor name: " << cast_arg_name;
VLOG(5) << "insert cast for subgraph input, arg tensor name: "
<< cast_arg_name;
cast_op->SetType("cast");
cast_op->SetAttr<int>("in_dtype", 4); // FP32
cast_op->SetAttr<int>("out_dtype", 5); // FP16
cast_op->SetAttr<int>("in_dtype", 5); // FP32
cast_op->SetAttr<int>("out_dtype", 4); // FP16
cast_op->SetInput("X", {cur_node});
cast_op->SetOutput("Out", {cast_arg_name});
cur_node = cast_arg_name;
do_insert = true;
}
return cur_node;
return std::make_pair(do_insert, cur_node);
}
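CheckInputAndInsert changes in two ways: the compatibility checks are inverted, so a layout or cast op is added only when the tensor type and the subgraph type are not compatible, and the return type becomes std::pair<bool, std::string> so callers can tell whether anything was inserted before remapping names. For a host FP32/NCHW input feeding an FP16/NHWC MLU subgraph, the resulting chain looks like this (sketch; tensor names follow the string_format calls above):

    // input                                              FP32, NCHW (host side)
    //   -> layout (NCHW -> NHWC)                         writes  input/layout
    //   -> cast   (in_dtype=5 FP32 -> out_dtype=4 FP16)  writes  input/layout/cast
    //   -> consumed by the ops inside the subgraph block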
std::string CheckOutputAndInsert(Scope* scope,
cpp::BlockDesc* block_desc,
const std::string& output_name,
const Type* tensor_type,
const Type* subgraph_type) {
std::pair<bool, std::string> CheckOutputAndInsert(
Scope* scope,
cpp::BlockDesc* block_desc,
const std::string& output_name,
const Type* tensor_type,
const Type* subgraph_type) {
auto cur_node = output_name;
bool do_insert = false;
cpp::OpDesc *layout_op = nullptr, *cast_op = nullptr;
// subgraph -> cast -> layout -> output
if (PrecisionCompatible(*tensor_type, *subgraph_type)) {
if (!PrecisionCompatible(*tensor_type, *subgraph_type)) {
cast_op = block_desc->AddOp<cpp::OpDesc>();
cast_op->SetType("cast");
cast_op->SetAttr<int>("in_dtype", 5); // FP16
cast_op->SetAttr<int>("out_dtype", 4); // FP32
cast_op->SetAttr<int>("in_dtype", 4); // FP16
cast_op->SetAttr<int>("out_dtype", 5); // FP32
do_insert = true;
}
if (DataLayoutCompatible(*tensor_type, *subgraph_type)) {
if (!DataLayoutCompatible(*tensor_type, *subgraph_type)) {
auto layout_arg_name = string_format("%s/layout", cur_node.c_str());
scope->Var(layout_arg_name);
VLOG(5) << "insert layout in subgraph, arg tensor name: "
VLOG(5) << "insert layout for subgraph output, arg tensor name: "
<< layout_arg_name;
layout_op = block_desc->AddOp<cpp::OpDesc>();
layout_op->SetType("layout");
layout_op->SetInput("Input", {layout_arg_name});
layout_op->SetOutput("Out", {cur_node});
cur_node = layout_arg_name;
do_insert = true;
}
if (cast_op) {
auto cast_arg_name = string_format("%s/cast", cur_node.c_str());
scope->Var(cast_arg_name);
VLOG(5) << "insert cast in subgraph, arg tensor name: " << cast_arg_name;
VLOG(5) << "insert cast for subgraph output, arg tensor name: "
<< cast_arg_name;
cast_op->SetInput("X", {cast_arg_name});
cast_op->SetOutput("Out", {cur_node});
cur_node = cast_arg_name;
}
return cur_node;
return std::make_pair(do_insert, cur_node);
}
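CheckOutputAndInsert builds the ops in the reverse order of the data flow (subgraph -> cast -> layout -> output, as the comment above notes): the layout op is wired first so that the original output name stays at the end of the chain, and the cast op is connected afterwards once its input tensor name is known. For an FP16/NHWC subgraph output that must come back as FP32/NCHW, the wiring is (sketch):

    // subgraph writes:                                out/layout/cast   (FP16, NHWC)
    //   cast   (in_dtype=4 FP16 -> out_dtype=5 FP32): out/layout/cast -> out/layout
    //   layout (NHWC -> NCHW):                        out/layout      -> out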
// insert cast op on mlu, to avoid cast on cpu, invoke before first run
@@ -711,22 +752,28 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
auto input_name = input->AsArg().name;
if (!(input->AsArg().is_weight || input->AsArg().is_persist)) {
i_names.emplace_back(input_name);
node_replace[input_name] = CheckInputAndInsert(op->scope(),
new_block_desc,
input_name,
input->AsArg().type,
subgraph_type);
auto ret = CheckInputAndInsert(op->scope(),
new_block_desc,
input_name,
input->AsArg().type,
subgraph_type);
if (ret.first) {
node_replace[input_name] = ret.second;
}
}
}
for (auto& output : subgraph_node->outlinks) {
auto output_name = output->AsArg().name;
if (!(output->AsArg().is_weight || output->AsArg().is_persist)) {
o_names.emplace_back(output_name);
node_replace[output_name] = CheckOutputAndInsert(op->scope(),
block_desc,
output_name,
output->AsArg().type,
subgraph_type);
auto ret = CheckOutputAndInsert(op->scope(),
block_desc,
output_name,
output->AsArg().type,
subgraph_type);
if (ret.first) {
node_replace[output_name] = ret.second;
}
}
}
@@ -736,34 +783,38 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
auto new_desc = new_block_desc->AddOp<cpp::OpDesc>();
*new_desc = *desc;
auto op_input_args = new_desc->InputArgumentNames();
for (auto& input_arg : op_input_args) {
auto op_input = new_desc->Input(input_arg);
for (auto& it : i_names) {
auto index = std::find(op_input.cbegin(), op_input.cend(), it);
if (index != op_input.cend()) {
index = op_input.erase(index);
op_input.emplace(index, node_replace.at(it));
VLOG(4) << new_desc->Type() << "] change input from " << it << " to "
<< node_replace.at(it);
if (desc->Type() != "layout" && desc->Type() != "cast") {
auto op_input_args = new_desc->InputArgumentNames();
for (auto& input_arg : op_input_args) {
auto op_input = new_desc->Input(input_arg);
for (auto& it : i_names) {
auto index = std::find(op_input.cbegin(), op_input.cend(), it);
if (index != op_input.cend() &&
node_replace.find(it) != node_replace.end()) {
index = op_input.erase(index);
op_input.emplace(index, node_replace.at(it));
VLOG(4) << new_desc->Type() << "] change input from " << it
<< " to " << node_replace.at(it);
}
}
new_desc->SetInput(input_arg, op_input);
}
new_desc->SetInput(input_arg, op_input);
}
auto op_output_args = new_desc->OutputArgumentNames();
for (auto& output_arg : op_output_args) {
auto op_output = new_desc->Output(output_arg);
for (auto& it : o_names) {
auto index = std::find(op_output.cbegin(), op_output.cend(), it);
if (index != op_output.cend()) {
index = op_output.erase(index);
op_output.emplace(index, node_replace.at(it));
VLOG(4) << new_desc->Type() << "] change output from " << it << " to "
<< node_replace.at(it);
auto op_output_args = new_desc->OutputArgumentNames();
for (auto& output_arg : op_output_args) {
auto op_output = new_desc->Output(output_arg);
for (auto& it : o_names) {
auto index = std::find(op_output.cbegin(), op_output.cend(), it);
if (index != op_output.cend() &&
node_replace.find(it) != node_replace.end()) {
index = op_output.erase(index);
op_output.emplace(index, node_replace.at(it));
VLOG(4) << new_desc->Type() << "] change output from " << it
<< " to " << node_replace.at(it);
}
}
new_desc->SetOutput(output_arg, op_output);
}
new_desc->SetOutput(output_arg, op_output);
}
}
op->SetSubBlock(new_block_desc);
......
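The renaming loop in AdjustSubgraph now skips ops of type "layout" and "cast", since those are exactly the ops just inserted by CheckInputAndInsert/CheckOutputAndInsert and they already reference the new intermediate tensors, and it guards the node_replace lookup: the map only holds entries for inputs and outputs where an op was actually inserted, so an unguarded .at() would throw std::out_of_range for untouched tensors. Sketch of the guarded remap (same names as in the hunk above):

    auto found = node_replace.find(it);
    if (index != op_input.cend() && found != node_replace.end()) {
      index = op_input.erase(index);
      op_input.emplace(index, found->second);
    }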
@@ -40,26 +40,21 @@ int CastConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(graph->HasNode(x_var_name));
auto x_tensor = graph->GetNode(x_var_name);
cnmlDataType_t data_type;
if (out_dtype == 4) {
data_type = CNML_DATA_FLOAT16;
} else if (out_dtype == 5) {
data_type = CNML_DATA_FLOAT32;
} else {
CHECK(0) << "Unsupported data_type";
}
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, data_type);
cnmlDataType_t out_type;
cnmlCastType_t cast_type;
if (in_dtype == 4 && out_dtype == 5) {
cast_type = CNML_CAST_FLOAT16_TO_FLOAT32;
out_type = CNML_DATA_FLOAT32;
} else if (in_dtype == 5 && out_dtype == 4) {
cast_type = CNML_CAST_FLOAT32_TO_FLOAT16;
out_type = CNML_DATA_FLOAT16;
} else {
CHECK(0) << "Unsupported cast type";
}
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, out_type);
cnmlBaseOp_t cast_op;
CNML_CALL(cnmlCreateCastOp(&cast_op,
cast_type,
......
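CastConverter now derives both the cnmlCastType_t and the element type of the MLU output tensor from the (in_dtype, out_dtype) pair, instead of typing the output from out_dtype before the pair is validated. Only the two FP16/FP32 directions are handled; a minimal sketch of the dispatch (hypothetical helper, same 4 = FP16 / 5 = FP32 encoding as above):

    inline std::pair<cnmlCastType_t, cnmlDataType_t> PlanCast(int in_dtype, int out_dtype) {
      if (in_dtype == 4 && out_dtype == 5) return {CNML_CAST_FLOAT16_TO_FLOAT32, CNML_DATA_FLOAT32};
      if (in_dtype == 5 && out_dtype == 4) return {CNML_CAST_FLOAT32_TO_FLOAT16, CNML_DATA_FLOAT16};
      CHECK(0) << "Unsupported cast type";
      return {CNML_CAST_FLOAT32_TO_FLOAT16, CNML_DATA_FLOAT16};  // unreachable
    }

The LayoutConverter hunks below follow the same idea: the transposed output tensor now keeps x_tensor->dtype() instead of being forced to graph->FPType(), so a pure layout transform no longer silently changes precision.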
@@ -60,7 +60,7 @@ int LayoutConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(0) << "Unsupport shape";
}
output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, x_tensor->dtype());
VLOG(3) << "layout transpose nchw to nhwc" << std::endl;
} else {
switch (x_dims.size()) {
@@ -84,7 +84,7 @@ int LayoutConverter(void* ctx, OpLite* op, KernelBase* kernel) {
output_dims,
CNML_TENSOR,
CNML_NCHW,
graph->FPType(),
x_tensor->dtype(),
CNML_NCHW);
}
cnmlBaseOp_t layout_op;
......
@@ -41,6 +41,7 @@ class IoCopyHostToMluCompute
auto mem_size = param.x->memory_size();
// LOG(INFO) << "copy size " << mem_size;
auto* data = param.y->mutable_data(TARGET(kMLU), mem_size);
param.y->set_precision(param.x->precision());
CopyFromHostSync(data, param.x->raw_data(), mem_size);
}
......
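IoCopyHostToMluCompute copies raw bytes with CopyFromHostSync, so without the added line the destination tensor would keep its default precision metadata. Mirroring the source precision keeps later precision queries on the copied tensor consistent with the data it actually holds, which matters once FP16 buffers start flowing through the io_copy path. The added line restated with a comment:

    // io_copy must carry the precision metadata across, not just the bytes.
    param.y->set_precision(param.x->precision());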
@@ -147,7 +147,7 @@ class SubgraphEngine : public subgraph::Engine {
origin_itensors_.clear();
origin_otensors_.clear();
auto data_order = block_desc_->GetOp<cpp::OpDesc>(0)->Type() == "cast"
auto data_order = block_desc_->GetOp<cpp::OpDesc>(0)->Type() == "layout"
? CNML_NCHW
: CNML_NHWC;
// Convert all of input data vars and added into the MLU IR graph
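Because CheckInputAndInsert now emits the layout op before the cast op (see the input-chain sketch earlier), the first op of the rewritten sub-block is "layout" whenever an NCHW-to-NHWC transform is needed, so the probe that picks the CNML data order for the inputs keys on "layout" instead of "cast". In effect:

    // The first op of the sub-block decides how input buffers are declared to CNML:
    //   first op == "layout"  ->  host inputs are still NCHW; the layout op transposes them on the MLU
    //   otherwise             ->  host inputs are assumed to already be NHWC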
@@ -166,6 +166,7 @@ class SubgraphEngine : public subgraph::Engine {
}
CHECK(input_tensor);
VLOG(4) << "subgraph input tensor " << input_name << std::endl;
auto input_node = graph->AddNode(input_name,
input_tensor->dims().Vectorize(),
CNML_TENSOR,
@@ -217,6 +218,7 @@ class SubgraphEngine : public subgraph::Engine {
graph->AddOutput(graph->GetNode(output_name));
auto output_tensor = scope_->FindMutableTensor(output_name);
origin_otensors_.push_back(output_tensor);
VLOG(4) << "subgraph output tensor " << output_name << std::endl;
// auto node = graph->GetNode(output_name);
// CHECK(p_data);
......