diff --git a/lite/core/mir/mlu_postprocess_pass.cc b/lite/core/mir/mlu_postprocess_pass.cc
index f3c32bc8292f15ca87fbd049ef7cc5868e1caee5..61021e62a873a9e60eb3ff95c43d742610fd1161 100644
--- a/lite/core/mir/mlu_postprocess_pass.cc
+++ b/lite/core/mir/mlu_postprocess_pass.cc
@@ -69,7 +69,8 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
   for (auto& kernel : kernels) {
     if (op_type == "cast") {
       const Type* in_arg_ty = kernel->GetInputDeclType("X");
-      if (PrecisionCompatibleTo(*in_arg_ty, *cur_node->AsArg().type)) {
+      if (PrecisionCompatibleTo(*in_arg_ty, *cur_node->AsArg().type) &&
+          DataLayoutCompatible(*in_arg_ty, *cur_node->AsArg().type)) {
         is_found = true;
       }
     } else if (op_type == "layout") {
@@ -564,6 +565,16 @@ void MLUPostprocessPass::ModifyLayout(SSAGraph* graph) {
                                     old_type->precision(),
                                     paddle::lite_api::DataLayoutType::kNHWC,
                                     old_type->device());
+          // modify inst feed to NHWC, while set_mlu_input_layout(kNHWC)
+          // invoked, to keep consistent with actual data layout
+          auto place = node.AsStmt().place();
+          place.layout = DATALAYOUT(kNHWC);
+          std::vector<Place> valid_places = {place};
+          auto updated_op_info = *node.AsStmt().op_info();
+          node.AsStmt().ResetOp(updated_op_info, valid_places, nullptr);
+          auto kernel = &(node.AsStmt().picked_kernel());
+          VLOG(4) << "kernel info: " << kernel->name();
+          node.AsStmt().op()->AttachKernel(kernel);
         }
       }
     }
diff --git a/lite/core/mir/subgraph/subgraph_pass.cc b/lite/core/mir/subgraph/subgraph_pass.cc
index a378e2ed642f3da940636c1bcf8637e8bf0d8f88..8e537d6c0fcd84f503a3349bad0b45eea12b6faf 100644
--- a/lite/core/mir/subgraph/subgraph_pass.cc
+++ b/lite/core/mir/subgraph/subgraph_pass.cc
@@ -95,7 +95,26 @@ void MLUSubgraphPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
       ++it;
     }
   }
+  // add x86 NHWC place
+  std::vector<PrecisionType> precisions{PRECISION(kFloat),
+                                        PRECISION(kFP16)};
+  if (lite::TargetWrapperMlu::UseFirstConv())
+    precisions.emplace_back(PRECISION(kInt8));
+  for (auto& prec : precisions) {
+    auto is_x86_nhwc = [prec](const Place& it) {
+      return it.layout == DATALAYOUT(kNHWC) && it.target == TARGET(kX86) &&
+             it.precision == prec;
+    };
+    if (std::find_if(v_places.cbegin(), v_places.cend(), is_x86_nhwc) ==
+        v_places.end()) {
+      v_places.emplace_back(Place{TARGET(kX86), prec, DATALAYOUT(kNHWC)});
+    }
+  }
   graph->SetValidPlaces(v_places);
+  VLOG(4) << "valid places after modified:";
+  for (auto& p : v_places) {
+    VLOG(4) << p.DebugString();
+  }
 #endif
 
   std::unordered_set<std::string> supported_lists;
diff --git a/lite/core/op_registry.cc b/lite/core/op_registry.cc
index c45c2e8dc1f2c9b5db1e472c025f1b48e8f74118..3e185f9d8acc838d75ac4770ed28cfbecc072837 100644
--- a/lite/core/op_registry.cc
+++ b/lite/core/op_registry.cc
@@ -182,6 +182,9 @@ KernelRegistry::KernelRegistry()
   INIT_FOR(kX86, kFloat, kNCHW);
   INIT_FOR(kX86, kFP16, kNCHW);
   INIT_FOR(kX86, kInt8, kNCHW);
+  INIT_FOR(kX86, kFloat, kNHWC);
+  INIT_FOR(kX86, kFP16, kNHWC);
+  INIT_FOR(kX86, kInt8, kNHWC);
   INIT_FOR(kX86, kAny, kNCHW);
   INIT_FOR(kX86, kAny, kAny);
   INIT_FOR(kX86, kInt64, kNCHW);
diff --git a/lite/core/op_registry.h b/lite/core/op_registry.h
index 5a89d14ad299c219812015fdc6eed14e83ac493d..f46451ec049fe439614986f3db89d3786b54794b 100644
--- a/lite/core/op_registry.h
+++ b/lite/core/op_registry.h
@@ -126,6 +126,15 @@ class KernelRegistry final {
                     KernelRegistryForTarget<TARGET(kX86),
                                             PRECISION(kInt8),
                                             DATALAYOUT(kNCHW)> *,  //
+                    KernelRegistryForTarget<TARGET(kX86),
+                                            PRECISION(kFloat),
+                                            DATALAYOUT(kNHWC)> *,  //
+                    KernelRegistryForTarget<TARGET(kX86),
+                                            PRECISION(kFP16),
+                                            DATALAYOUT(kNHWC)> *,  //
+                    KernelRegistryForTarget<TARGET(kX86),
+                                            PRECISION(kInt8),
+                                            DATALAYOUT(kNHWC)> *,  //
                     KernelRegistryForTarget<TARGET(kX86),
                                             PRECISION(kAny),
                                             DATALAYOUT(kNCHW)> *,  //
diff --git a/lite/kernels/x86/cast_compute.cc b/lite/kernels/x86/cast_compute.cc
index bbb63e595269667dedebeafd83cc962d1d0fb878..d61de28cb58856e63a558374a576f660cc3b1f88 100644
--- a/lite/kernels/x86/cast_compute.cc
+++ b/lite/kernels/x86/cast_compute.cc
@@ -20,7 +20,10 @@ REGISTER_LITE_KERNEL(cast,
                      kNCHW,
                      paddle::lite::kernels::x86::CastCompute<float>,
                      def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kX86),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kAny))})
     .BindOutput("Out",
                 {LiteType::GetTensorTy(TARGET(kX86))})
     .Finalize();
@@ -31,6 +34,9 @@ REGISTER_LITE_KERNEL(
     kNCHW,
     paddle::lite::kernels::x86::CastCompute<::paddle::lite::fluid::float16>,
     fp16_to_any)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFP16))})
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kX86),
+                                      PRECISION(kFP16),
+                                      DATALAYOUT(kAny))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
     .Finalize();