From fffb2fe60fd669fb9f9a2e17abf85c8fd64204ac Mon Sep 17 00:00:00 2001
From: jackzhang235
Date: Thu, 2 Apr 2020 12:00:53 +0000
Subject: [PATCH] move mlu passes forward

---
 lite/core/mir/mlu_postprocess_pass.cc | 79 ++++++++++++++++++++-------
 lite/core/mir/mlu_postprocess_pass.h  |  2 +
 lite/core/op_registry.cc              |  2 +
 lite/core/op_registry.h               |  3 +
 lite/core/optimizer.h                 |  8 +--
 lite/kernels/mlu/io_copy_compute.cc   | 40 +++++++++++---
 lite/kernels/mlu/layout_compute.cc    | 22 ++++----
 lite/kernels/mlu/layout_compute.h     |  4 +-
 lite/kernels/mlu/subgraph_compute.cc  | 20 +++++--
 9 files changed, 130 insertions(+), 50 deletions(-)

diff --git a/lite/core/mir/mlu_postprocess_pass.cc b/lite/core/mir/mlu_postprocess_pass.cc
index 4ea24390d6..23303d6f60 100644
--- a/lite/core/mir/mlu_postprocess_pass.cc
+++ b/lite/core/mir/mlu_postprocess_pass.cc
@@ -105,13 +105,8 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
       // we pick the kernel
       cast_inst->AsStmt(op_type, std::move(selected_kernels), cast_op);
       auto& stmt = cast_inst->AsStmt();
-      if (op_type == "layout") {
-        stmt.picked_kernel().SetContext(
-            ContextScheduler::Global().NewContext(TARGET(kX86)));
-      } else {
-        stmt.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
-            stmt.picked_kernel().target()));
-      }
+      stmt.picked_kernel().SetContext(
+          ContextScheduler::Global().NewContext(stmt.picked_kernel().target()));
       break;
     }
   }
@@ -185,7 +180,8 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
       const Type* in_arg_ty = kernel->GetInputDeclType("Input");
       const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
       if (DataLayoutCompatible(*in_arg_ty, *cast_type) &&
-          DataLayoutCompatible(*out_arg_ty, *cur_node->AsArg().type)) {
+          DataLayoutCompatible(*out_arg_ty, *cur_node->AsArg().type) &&
+          PrecisionCompatibleTo(*in_arg_ty, *cast_type)) {
         is_found = true;
       }
     } else if (op_type == "io_copy") {
@@ -203,13 +199,8 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
       // we pick the kernel
       cast_inst->AsStmt(op_type, std::move(selected_kernels), cast_op);
       auto& stmt = cast_inst->AsStmt();
-      if (op_type == "layout") {
-        stmt.picked_kernel().SetContext(
-            ContextScheduler::Global().NewContext(TARGET(kX86)));
-      } else {
-        stmt.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
-            stmt.picked_kernel().target()));
-      }
+      stmt.picked_kernel().SetContext(
+          ContextScheduler::Global().NewContext(stmt.picked_kernel().target()));
       break;
     }
   }
@@ -517,6 +508,50 @@ void MLUPostprocessPass::GatherAndModifyFirstConvNodes(SSAGraph* graph) {
   }
 }
 
+void MLUPostprocessPass::ModifyInputOutputDataType(SSAGraph* graph) {
+  for (auto& node : graph->mutable_nodes()) {
+    if (node.IsStmt() && node.AsStmt().op_type() == "subgraph") {
+      for (auto& in_node : node.inlinks) {
+        const auto* in_node_type = in_node->AsArg().type;
+        VLOG(4) << "MLU subgraph input type: " << in_node->AsArg().name
+                << *in_node_type;
+        if (in_node->AsArg().is_weight || in_node->AsArg().is_persist) {
+          CHECK(in_node_type->target() == TARGET(kHost) &&
+                in_node_type->precision() == PRECISION(kAny) &&
+                in_node_type->layout() == DATALAYOUT(kNCHW))
+              << "MLU subgraph unexpected persistent input type!";
+          in_node->AsArg().type = LiteType::GetTensorTy(
+              TARGET(kMLU), PRECISION(kAny), DATALAYOUT(kNHWC));
+        } else {
+          CHECK((in_node_type->target() == TARGET(kHost) ||
+                 in_node_type->target() == TARGET(kX86)) &&
+                in_node_type->precision() == PRECISION(kFloat) &&
+                in_node_type->layout() == DATALAYOUT(kNCHW))
+              << "MLU subgraph unexpected common input type!";
+        }
+      }
+      for (auto& out_node : node.outlinks) {
+        const auto* out_node_type = out_node->AsArg().type;
+        VLOG(4) << "MLU subgraph output type: " << out_node->AsArg().name
+                << *out_node_type;
+        if (out_node->AsArg().is_weight || out_node->AsArg().is_persist) {
+          CHECK(out_node_type->target() == TARGET(kHost) &&
+                out_node_type->precision() == PRECISION(kAny) &&
+                out_node_type->layout() == DATALAYOUT(kNCHW))
+              << "MLU subgraph unexpected persistent input type!";
+          out_node->AsArg().type = LiteType::GetTensorTy(
+              TARGET(kMLU), PRECISION(kAny), DATALAYOUT(kNHWC));
+        } else {
+          CHECK(out_node_type->precision() == PRECISION(kFloat))
+              << "MLU subgraph unexpected common output type!";
+          out_node->AsArg().type = LiteType::GetTensorTy(
+              TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
+        }
+      }
+    }
+  }
+}
+
 void MLUPostprocessPass::ModifyLayout(SSAGraph* graph) {
   for (auto& node : graph->mutable_nodes()) {
     if (!node.IsStmt()) continue;
@@ -562,14 +597,16 @@
 }
 
 void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
-// currently for non-persistent input and output args, mlu subgraph op
-// only support float16/float32 data type
+  // currently for non-persistent input and output args, mlu subgraph op
+  // only support float16/float32 data type
 
-// in two situations as folllows:
-// 1: feed->arg_in->subgraph->... 2: ...->subgraph->arg_out->fetch;
-// arg_in and arg_out are assumed to be NHWC which user should be aware of.
-// Thus here we change these args' layout to NHWC
+  // in two situations as folllows:
+  // 1: feed->arg_in->subgraph->... 2: ...->subgraph->arg_out->fetch;
+  // arg_in and arg_out are assumed to be NHWC which user should be aware of.
+  // Thus here we change these args' layout to NHWC
 #ifdef LITE_WITH_MLU
+  ModifyInputOutputDataType(graph.get());
+
   if (lite::TargetWrapperMlu::InputLayout() == DATALAYOUT(kNHWC)) {
     ModifyLayout(graph.get());
   }
diff --git a/lite/core/mir/mlu_postprocess_pass.h b/lite/core/mir/mlu_postprocess_pass.h
index 688dd06fb5..eddcf67c0b 100644
--- a/lite/core/mir/mlu_postprocess_pass.h
+++ b/lite/core/mir/mlu_postprocess_pass.h
@@ -79,6 +79,8 @@ class MLUPostprocessPass : public ProgramPass {
                             const Type** arg_type,
                             SSAGraph* graph);
 
+  void ModifyInputOutputDataType(SSAGraph* graph);
+
   void ModifyLayout(SSAGraph* graph);
 
   bool NeedInsert(Node* node, const Type* inst_type);
diff --git a/lite/core/op_registry.cc b/lite/core/op_registry.cc
index 0c8d42f4e2..c45c2e8dc1 100644
--- a/lite/core/op_registry.cc
+++ b/lite/core/op_registry.cc
@@ -180,6 +180,8 @@ KernelRegistry::KernelRegistry()
   INIT_FOR(kHost, kInt64, kAny);
 
   INIT_FOR(kX86, kFloat, kNCHW);
+  INIT_FOR(kX86, kFP16, kNCHW);
+  INIT_FOR(kX86, kInt8, kNCHW);
   INIT_FOR(kX86, kAny, kNCHW);
   INIT_FOR(kX86, kAny, kAny);
   INIT_FOR(kX86, kInt64, kNCHW);
diff --git a/lite/core/op_registry.h b/lite/core/op_registry.h
index 65279b74c5..5a89d14ad2 100644
--- a/lite/core/op_registry.h
+++ b/lite/core/op_registry.h
@@ -120,6 +120,9 @@ class KernelRegistry final {
                      KernelRegistryForTarget *,  //
+                     KernelRegistryForTarget *,  //
                      KernelRegistryForTarget *,  //
                      KernelRegistryForTarget *,  //
diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h
index 2fb2799682..2d36ea3966 100644
--- a/lite/core/optimizer.h
+++ b/lite/core/optimizer.h
@@ -90,8 +90,12 @@ class Optimizer {
          "xpu_subgraph_pass",
          "bm_subgraph_pass",
          "rknpu_subgraph_pass",
+         "mlu_subgraph_pass",
+
          "static_kernel_pick_pass",        // pick original kernel from graph
          "variable_place_inference_pass",  // inference arg/var's
+
+         "mlu_postprocess_pass",
                                            // info(target/precision/layout/device)
                                            // using kernel info
          "argument_type_display_pass",  // debug pass: show arg-type-node's
@@ -121,13 +125,9 @@
          "variable_place_inference_pass",  //
          "argument_type_display_pass",
 
-         "mlu_subgraph_pass",
-
          "runtime_context_assign_pass",
          "argument_type_display_pass",
 
-         "mlu_postprocess_pass",
-
          "memory_optimize_pass"}};
 
   if (passes.size() == 1) {
diff --git a/lite/kernels/mlu/io_copy_compute.cc b/lite/kernels/mlu/io_copy_compute.cc
index 7520d50034..4bdbf5e23c 100644
--- a/lite/kernels/mlu/io_copy_compute.cc
+++ b/lite/kernels/mlu/io_copy_compute.cc
@@ -102,8 +102,14 @@ REGISTER_LITE_KERNEL(
     kNHWC,
     paddle::lite::kernels::mlu::IoCopyHostToMluCompute,
     host_to_device_kFloat)
-    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kHost))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kMLU))})
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kAny),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kMLU),
+                                       PRECISION(kAny),
+                                       DATALAYOUT(kAny))})
     .Finalize();
 
 REGISTER_LITE_KERNEL(
@@ -113,8 +119,14 @@
     kNHWC,
     paddle::lite::kernels::mlu::IoCopyHostToMluCompute,
     host_to_device_kFP16)
-    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kHost))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kMLU))})
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kAny),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kMLU),
+                                       PRECISION(kAny),
+                                       DATALAYOUT(kAny))})
     .Finalize();
 
 REGISTER_LITE_KERNEL(
@@ -124,8 +136,14 @@
     kNHWC,
     paddle::lite::kernels::mlu::IoCopyMluToHostCompute,
     device_to_host_kFloat)
-    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kMLU))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kMLU),
+                                      PRECISION(kAny),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kHost),
+                                       PRECISION(kAny),
+                                       DATALAYOUT(kAny))})
     .Finalize();
 
 REGISTER_LITE_KERNEL(
@@ -135,6 +153,12 @@
     kNHWC,
     paddle::lite::kernels::mlu::IoCopyMluToHostCompute,
     device_to_host_kFP16)
-    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kMLU))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kMLU),
+                                      PRECISION(kAny),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kHost),
+                                       PRECISION(kAny),
+                                       DATALAYOUT(kAny))})
     .Finalize();
diff --git a/lite/kernels/mlu/layout_compute.cc b/lite/kernels/mlu/layout_compute.cc
index d4e16734d6..42b12740ff 100644
--- a/lite/kernels/mlu/layout_compute.cc
+++ b/lite/kernels/mlu/layout_compute.cc
@@ -24,9 +24,9 @@ namespace mlu {}  // namespace mlu
 
 REGISTER_LITE_KERNEL(
     layout,
-    kMLU,
+    kX86,
     kFloat,
-    kNHWC,
+    kNCHW,
     paddle::lite::kernels::mlu::LayoutNhwcToNchwCompute,
     def_layout_nhwc2nchw_fp32)
     .BindInput("Input",
@@ -41,9 +41,9 @@
 
 REGISTER_LITE_KERNEL(
     layout,
-    kMLU,
+    kX86,
     kFP16,
-    kNHWC,
+    kNCHW,
     paddle::lite::kernels::mlu::LayoutNhwcToNchwCompute,
     def_layout_nhwc2nchw_fp16)
     .BindInput("Input",
@@ -58,9 +58,9 @@
 
 REGISTER_LITE_KERNEL(
     layout,
-    kMLU,
+    kX86,
     kFloat,
-    kNHWC,
+    kNCHW,
     paddle::lite::kernels::mlu::LayoutNchwToNhwcCompute,
     def_layout_nchw2nhwc_fp32)
     .BindInput("Input",
@@ -75,9 +75,9 @@
 
 REGISTER_LITE_KERNEL(
     layout,
-    kMLU,
+    kX86,
     kFP16,
-    kNHWC,
+    kNCHW,
     paddle::lite::kernels::mlu::LayoutNchwToNhwcCompute,
     def_layout_nchw2nhwc_fp16)
     .BindInput("Input",
@@ -92,11 +92,11 @@
 
 REGISTER_LITE_KERNEL(
     layout,
-    kMLU,
+    kX86,
     kInt8,
-    kNHWC,
+    kNCHW,
     paddle::lite::kernels::mlu::LayoutNchwToNhwcCompute,
-    def_layout_nchw2nhwc_fp32_int8)
+    def_layout_nchw2nhwc_int8)
     .BindInput("Input",
                {LiteType::GetTensorTy(TARGET(kHost),
                                       PRECISION(kInt8),
diff --git a/lite/kernels/mlu/layout_compute.h b/lite/kernels/mlu/layout_compute.h
index 8e98f1ba6e..75c3a90960 100644
--- a/lite/kernels/mlu/layout_compute.h
+++ b/lite/kernels/mlu/layout_compute.h
@@ -73,7 +73,7 @@ inline void LayoutTransCompute(const int dim,
 
 template
 class LayoutNchwToNhwcCompute
-    : public KernelLite {
+    : public KernelLite {
  public:
   using param_t = operators::LayoutParam;
 
@@ -122,7 +122,7 @@ class LayoutNchwToNhwcCompute
 
 template
 class LayoutNhwcToNchwCompute
-    : public KernelLite {
+    : public KernelLite {
  public:
   using param_t = operators::LayoutParam;
 
diff --git a/lite/kernels/mlu/subgraph_compute.cc b/lite/kernels/mlu/subgraph_compute.cc
index 73ca9dcc20..450031021d 100644
--- a/lite/kernels/mlu/subgraph_compute.cc
+++ b/lite/kernels/mlu/subgraph_compute.cc
@@ -36,8 +36,14 @@ REGISTER_LITE_KERNEL(
     kNHWC,
     paddle::lite::kernels::mlu::SubgraphCompute,
     def_kFloat)
-    .BindInput("Inputs", {LiteType::GetTensorTy(TARGET(kMLU))})
-    .BindOutput("Outputs", {LiteType::GetTensorTy(TARGET(kMLU))})
+    .BindInput("Inputs",
+               {LiteType::GetTensorTy(TARGET(kMLU),
+                                      PRECISION(kAny),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Outputs",
+                {LiteType::GetTensorTy(TARGET(kMLU),
+                                       PRECISION(kAny),
+                                       DATALAYOUT(kAny))})
     .Finalize();
 
 REGISTER_LITE_KERNEL(
@@ -47,6 +53,12 @@
     kNHWC,
     paddle::lite::kernels::mlu::SubgraphCompute,
     def_FP16)
-    .BindInput("Inputs", {LiteType::GetTensorTy(TARGET(kMLU))})
-    .BindOutput("Outputs", {LiteType::GetTensorTy(TARGET(kMLU))})
+    .BindInput("Inputs",
+               {LiteType::GetTensorTy(TARGET(kMLU),
+                                      PRECISION(kAny),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Outputs",
+                {LiteType::GetTensorTy(TARGET(kMLU),
+                                       PRECISION(kAny),
+                                       DATALAYOUT(kAny))})
     .Finalize();
--
GitLab