Unverified commit ded3e7af, authored by zhaoying9105 and committed by GitHub

add int32 mlu io copy out support & do not ModifyInputOutputDataType for int32 (#106)

1. Add int32 support to the MLU io_copy device-to-host (copy-out) kernel.
2. Do not rewrite the data type for int32 in ModifyInputOutputDataType: accept int32 boundary tensors alongside float and keep their precision.
Parent 8b5ea414
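Taken together, the changes let int32 tensors cross the MLU subgraph boundary without being rewritten to float. Below is a minimal, self-contained sketch of the widened acceptance condition for common (non-quantized) boundary inputs; the enums and the helper name are hypothetical stand-ins for this illustration, not the actual Paddle-Lite types.

```cpp
// Hypothetical stand-ins for TARGET/PRECISION/DATALAYOUT; illustration only.
#include <iostream>

enum class Target { kHost, kX86, kMLU };
enum class Precision { kFloat, kInt32, kInt8 };
enum class Layout { kNCHW, kNHWC };

// Mirrors the widened CHECK: a common subgraph input is accepted when it is a
// host/x86 NCHW tensor whose precision is either float or int32.
bool AcceptableCommonInput(Target t, Precision p, Layout l) {
  return (t == Target::kHost || t == Target::kX86) &&
         (p == Precision::kFloat || p == Precision::kInt32) &&
         l == Layout::kNCHW;
}

int main() {
  std::cout << AcceptableCommonInput(Target::kHost, Precision::kInt32,
                                     Layout::kNCHW)  // 1: now accepted
            << AcceptableCommonInput(Target::kX86, Precision::kInt8,
                                     Layout::kNCHW)  // 0: still rejected
            << std::endl;
}
```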
@@ -551,7 +551,8 @@ void MLUPostprocessPass::ModifyInputOutputDataType(SSAGraph* graph) {
     } else {
       CHECK((in_node_type->target() == TARGET(kHost) ||
              in_node_type->target() == TARGET(kX86)) &&
-            in_node_type->precision() == PRECISION(kFloat) &&
+            (in_node_type->precision() == PRECISION(kFloat) ||
+             in_node_type->precision() == PRECISION(kInt32)) &&
             in_node_type->layout() == DATALAYOUT(kNCHW))
           << "MLU subgraph unexpected common input type!";
     }
@@ -574,7 +575,8 @@ void MLUPostprocessPass::ModifyInputOutputDataType(SSAGraph* graph) {
         out_arg.type = LiteType::GetTensorTy(
             TARGET(kMLU), PRECISION(kAny), DATALAYOUT(kNHWC));
       } else {
-        CHECK(out_node_type->precision() == PRECISION(kFloat))
+        CHECK(out_node_type->precision() == PRECISION(kFloat) ||
+              out_node_type->precision() == PRECISION(kInt32))
             << "MLU subgraph unexpected common output type!";
         if (out_node->outlinks.empty()) {
           out_arg.type = LiteType::GetTensorTy(TARGET(kHost),
@@ -584,7 +586,7 @@ void MLUPostprocessPass::ModifyInputOutputDataType(SSAGraph* graph) {
                   << out_node_type->name();
         } else {
           out_arg.type = LiteType::GetTensorTy(
-              TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
+              TARGET(kHost), out_node_type->precision(), DATALAYOUT(kNCHW));
           VLOG(4) << "output node type: " << out_arg.name
                   << out_node_type->name();
         }
@@ -706,7 +708,8 @@ std::pair<bool, std::string> CheckOutputAndInsert(
   size_t cast_idx = 0;
   // subgraph -> cast -> layout -> output
-  if (!PrecisionCompatible(*tensor_type, *subgraph_type)) {
+  if (!PrecisionCompatible(*tensor_type, *subgraph_type) &&
+      tensor_type->precision() != PRECISION(kInt32)) {
     cast_op = block_desc->AddOp<cpp::OpDesc>();
     cast_idx = block_desc->OpsSize() - 1;
     CHECK_EQ(cast_op, block_desc->GetOp<cpp::OpDesc>(cast_idx));
......
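The last two hunks above change how outputs leave the subgraph: the host-side output type now keeps the producer's precision instead of being hard-coded to float, and the cast op that CheckOutputAndInsert would normally add on a precision mismatch is skipped for int32 outputs (the io_copy kernel registered below moves them to the host as-is). A hedged sketch of that decision follows; the helper names and the notion of "compatible" used here are assumptions for illustration, not the real PrecisionCompatible implementation.

```cpp
// Illustration only: hypothetical names, not the pass's real helpers.
enum class Precision { kFloat, kFP16, kInt32 };

// Assumed stand-in for PrecisionCompatible: identical precisions are
// compatible, and the two float flavours are treated as mutually compatible.
bool PrecisionCompatibleSketch(Precision tensor, Precision subgraph) {
  auto floatish = [](Precision p) {
    return p == Precision::kFloat || p == Precision::kFP16;
  };
  return tensor == subgraph || (floatish(tensor) && floatish(subgraph));
}

// Mirrors the new guard: insert a cast op only when the precisions are
// incompatible AND the output tensor is not int32.
bool NeedsCastBeforeOutput(Precision tensor, Precision subgraph) {
  return !PrecisionCompatibleSketch(tensor, subgraph) &&
         tensor != Precision::kInt32;
}
```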
@@ -151,6 +151,23 @@ REGISTER_LITE_KERNEL(
                                       DATALAYOUT(kAny))})
     .Finalize();
+REGISTER_LITE_KERNEL(
+    io_copy,
+    kMLU,
+    kInt32,
+    kNHWC,
+    paddle::lite::kernels::mlu::IoCopyMluToHostCompute<PRECISION(kInt32)>,
+    device_to_host_kInt32)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kMLU),
+                                      PRECISION(kInt32),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Out",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kInt32),
+                                      DATALAYOUT(kAny))})
+    .Finalize();
 REGISTER_LITE_KERNEL(
     io_copy,
     kMLU,
......
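The new registration exposes the existing, precision-templated IoCopyMluToHostCompute kernel under the (kMLU, kInt32, kNHWC) key, so int32 results can be copied off the device without a cast. Conceptually, a device-to-host copy only depends on the element size, which is why instantiating the template for a new precision is enough; the sketch below illustrates that idea with hypothetical names and a memcpy stand-in for the real MLU runtime copy.

```cpp
// Illustration only: hypothetical names; a real kernel would call the MLU
// runtime rather than memcpy.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

enum class Precision { kFloat, kInt32 };

template <Precision P> struct ElementOf;
template <> struct ElementOf<Precision::kFloat> { using type = float; };
template <> struct ElementOf<Precision::kInt32> { using type = int32_t; };

// Stand-in for the device-to-host transfer; only the byte count matters.
void DeviceToHostCopySketch(void* dst, const void* src, std::size_t bytes) {
  std::memcpy(dst, src, bytes);
}

// A precision-templated copy-out kernel body: registering it for a new
// precision reuses the same logic with a different element type.
template <Precision P>
void IoCopyDeviceToHostSketch(const void* device_buf,
                              typename ElementOf<P>::type* host_buf,
                              std::size_t num_elements) {
  DeviceToHostCopySketch(host_buf, device_buf,
                         num_elements * sizeof(typename ElementOf<P>::type));
}

int main() {
  std::vector<int32_t> fake_device = {1, 2, 3};
  std::vector<int32_t> host(fake_device.size());
  IoCopyDeviceToHostSketch<Precision::kInt32>(fake_device.data(), host.data(),
                                              host.size());
  std::cout << host[2] << std::endl;  // prints 3
}
```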