From 53b6e51b1020f806ecd0c98e0cdc4e3b20c8a75b Mon Sep 17 00:00:00 2001
From: dingminghui
Date: Fri, 24 Apr 2020 19:36:54 +0800
Subject: [PATCH] fix(place): add x86 nhwc place in subgraph_pass

Since the input tensor is modified to NHWC, which is not compatible
with the cast kernel in NCHW layout, type_layout_cast_pass would insert
a layout instruction to transform the data. Use an NHWC cast kernel to
avoid that extra layout transform.
---
 lite/core/mir/mlu_postprocess_pass.cc   | 13 ++++++++++++-
 lite/core/mir/subgraph/subgraph_pass.cc | 19 +++++++++++++++++++
 lite/core/op_registry.cc                |  3 +++
 lite/core/op_registry.h                 |  9 +++++++++
 lite/kernels/x86/cast_compute.cc        | 10 ++++++++--
 5 files changed, 51 insertions(+), 3 deletions(-)

diff --git a/lite/core/mir/mlu_postprocess_pass.cc b/lite/core/mir/mlu_postprocess_pass.cc
index f3c32bc829..61021e62a8 100644
--- a/lite/core/mir/mlu_postprocess_pass.cc
+++ b/lite/core/mir/mlu_postprocess_pass.cc
@@ -69,7 +69,8 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
   for (auto& kernel : kernels) {
     if (op_type == "cast") {
       const Type* in_arg_ty = kernel->GetInputDeclType("X");
-      if (PrecisionCompatibleTo(*in_arg_ty, *cur_node->AsArg().type)) {
+      if (PrecisionCompatibleTo(*in_arg_ty, *cur_node->AsArg().type) &&
+          DataLayoutCompatible(*in_arg_ty, *cur_node->AsArg().type)) {
         is_found = true;
       }
     } else if (op_type == "layout") {
@@ -564,6 +565,16 @@ void MLUPostprocessPass::ModifyLayout(SSAGraph* graph) {
             old_type->precision(),
             paddle::lite_api::DataLayoutType::kNHWC,
             old_type->device());
+        // modify inst feed to NHWC, while set_mlu_input_layout(kNHWC)
+        // invoked, to keep consistent with actual data layout
+        auto place = node.AsStmt().place();
+        place.layout = DATALAYOUT(kNHWC);
+        std::vector<Place> valid_places = {place};
+        auto updated_op_info = *node.AsStmt().op_info();
+        node.AsStmt().ResetOp(updated_op_info, valid_places, nullptr);
+        auto kernel = &(node.AsStmt().picked_kernel());
+        VLOG(4) << "kernel info: " << kernel->name();
+        node.AsStmt().op()->AttachKernel(kernel);
       }
     }
   }
diff --git a/lite/core/mir/subgraph/subgraph_pass.cc b/lite/core/mir/subgraph/subgraph_pass.cc
index a378e2ed64..8e537d6c0f 100644
--- a/lite/core/mir/subgraph/subgraph_pass.cc
+++ b/lite/core/mir/subgraph/subgraph_pass.cc
@@ -95,7 +95,26 @@ void MLUSubgraphPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
       ++it;
     }
   }
+  // add x86 NHWC place
+  std::vector<PrecisionType> precisions{PRECISION(kFloat),
+                                        PRECISION(kFP16)};
+  if (lite::TargetWrapperMlu::UseFirstConv())
+    precisions.emplace_back(PRECISION(kInt8));
+  for (auto& prec : precisions) {
+    auto is_x86_nhwc = [prec](const Place& it) {
+      return it.layout == DATALAYOUT(kNHWC) && it.target == TARGET(kX86) &&
+             it.precision == prec;
+    };
+    if (std::find_if(v_places.cbegin(), v_places.cend(), is_x86_nhwc) ==
+        v_places.end()) {
+      v_places.emplace_back(Place{TARGET(kX86), prec, DATALAYOUT(kNHWC)});
+    }
+  }
   graph->SetValidPlaces(v_places);
+  VLOG(4) << "valid places after modified:";
+  for (auto& p : v_places) {
+    VLOG(4) << p.DebugString();
+  }
 #endif
 
   std::unordered_set<std::string> supported_lists;
diff --git a/lite/core/op_registry.cc b/lite/core/op_registry.cc
index c45c2e8dc1..3e185f9d8a 100644
--- a/lite/core/op_registry.cc
+++ b/lite/core/op_registry.cc
@@ -182,6 +182,9 @@ KernelRegistry::KernelRegistry()
   INIT_FOR(kX86, kFloat, kNCHW);
   INIT_FOR(kX86, kFP16, kNCHW);
   INIT_FOR(kX86, kInt8, kNCHW);
+  INIT_FOR(kX86, kFloat, kNHWC);
+  INIT_FOR(kX86, kFP16, kNHWC);
+  INIT_FOR(kX86, kInt8, kNHWC);
   INIT_FOR(kX86, kAny, kNCHW);
   INIT_FOR(kX86, kAny, kAny);
   INIT_FOR(kX86, kInt64, kNCHW);
diff --git a/lite/core/op_registry.h b/lite/core/op_registry.h
index 5a89d14ad2..f46451ec04 100644
--- a/lite/core/op_registry.h
+++ b/lite/core/op_registry.h
@@ -126,6 +126,15 @@ class KernelRegistry final {
               KernelRegistryForTarget<TARGET(kX86),
                                       PRECISION(kInt8),
                                       DATALAYOUT(kNCHW)> *,  //
+              KernelRegistryForTarget<TARGET(kX86),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNHWC)> *,  //
+              KernelRegistryForTarget<TARGET(kX86),
+                                      PRECISION(kFP16),
+                                      DATALAYOUT(kNHWC)> *,  //
+              KernelRegistryForTarget<TARGET(kX86),
+                                      PRECISION(kInt8),
+                                      DATALAYOUT(kNHWC)> *,  //
               KernelRegistryForTarget<TARGET(kX86),
                                       PRECISION(kAny),
                                       DATALAYOUT(kNCHW)> *,  //
diff --git a/lite/kernels/x86/cast_compute.cc b/lite/kernels/x86/cast_compute.cc
index bbb63e5952..d61de28cb5 100644
--- a/lite/kernels/x86/cast_compute.cc
+++ b/lite/kernels/x86/cast_compute.cc
@@ -20,7 +20,10 @@ REGISTER_LITE_KERNEL(cast,
                      kNCHW,
                      paddle::lite::kernels::x86::CastCompute<float>,
                      def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kX86),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kAny))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
     .Finalize();
 
@@ -31,6 +34,9 @@ REGISTER_LITE_KERNEL(
     cast,
     kX86,
     kFP16,
     kNCHW,
     paddle::lite::kernels::x86::CastCompute<::paddle::lite::fluid::float16>,
     fp16_to_any)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFP16))})
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kX86),
+                                      PRECISION(kFP16),
+                                      DATALAYOUT(kAny))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
     .Finalize();
-- 
GitLab
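
Note (editorial, not part of the patch): the fix hinges on the cast kernel now
binding its input as DATALAYOUT(kAny), so the DataLayoutCompatible() check added
in InsertCastBefore() passes for NHWC tensors and type_layout_cast_pass no
longer inserts a layout-transform instruction. Below is a minimal standalone
sketch of that idea; DataLayoutCompatible here is a hypothetical stand-in with
the assumed "kAny matches anything" semantics, not the actual Paddle-Lite
implementation.

    // Minimal sketch, assuming a kernel declaring kAny accepts any layout.
    #include <iostream>

    enum class DataLayout { kAny, kNCHW, kNHWC };

    // Hypothetical stand-in for the DataLayoutCompatible() helper the pass calls.
    bool DataLayoutCompatible(DataLayout kernel_decl, DataLayout tensor_layout) {
      return kernel_decl == DataLayout::kAny || kernel_decl == tensor_layout;
    }

    int main() {
      // Before the patch: the x86 cast kernel declared kNCHW, so an NHWC input
      // failed the check and type_layout_cast_pass inserted a layout transform.
      std::cout << DataLayoutCompatible(DataLayout::kNCHW, DataLayout::kNHWC)
                << "\n";  // prints 0
      // After the patch: the kernel binds its input as DATALAYOUT(kAny), so the
      // NHWC tensor is accepted directly and no layout instruction is needed.
      std::cout << DataLayoutCompatible(DataLayout::kAny, DataLayout::kNHWC)
                << "\n";  // prints 1
      return 0;
    }

The x86 NHWC places added in subgraph_pass.cc complete the picture: an
NHWC-layout kernel can only be picked if a matching Place is among the graph's
valid places, which is why the pass appends {kX86, prec, kNHWC} entries (when
absent) before calling SetValidPlaces.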