From 3785aab8b1d2d661a95cfec216ae12c2d3c66fee Mon Sep 17 00:00:00 2001
From: jackzhang235
Date: Sat, 28 Mar 2020 09:28:58 +0000
Subject: [PATCH] support mlu set_input_layout

---
 lite/api/paddle_api.cc                | 13 ++++++-------
 lite/api/paddle_api.h                 | 13 ++++++-------
 lite/api/python/pybind/pybind.cc      | 11 ++++++-----
 lite/core/device_info.cc              |  3 ++-
 lite/core/device_info.h               |  2 +-
 lite/core/mir/mlu_postprocess_pass.cc | 25 ++++++++++++++++--------
 lite/core/mir/mlu_postprocess_pass.h  |  2 +-
 lite/kernels/mlu/bridges/utility.h    |  2 +-
 lite/kernels/mlu/layout_compute.cc    | 17 ++++++++++++++++
 lite/kernels/mlu/layout_compute.h     | 28 +++++++++++++++++++++++----
 10 files changed, 81 insertions(+), 35 deletions(-)

diff --git a/lite/api/paddle_api.cc b/lite/api/paddle_api.cc
index 6f4de971a7..91edb2cda7 100644
--- a/lite/api/paddle_api.cc
+++ b/lite/api/paddle_api.cc
@@ -203,23 +203,22 @@ void ConfigBase::set_threads(int threads) {
 #endif
 }
 
-void CxxConfig::mlu_set_mlu_core_version(
-    lite_api::MLUCoreVersion core_version) {
+void CxxConfig::set_mlu_core_version(lite_api::MLUCoreVersion core_version) {
   mlu_core_version_ = core_version;
 }
-void CxxConfig::mlu_set_mlu_core_number(int core_number) {
+void CxxConfig::set_mlu_core_number(int core_number) {
   mlu_core_number_ = core_number;
 }
-void CxxConfig::mlu_set_input_layout()(DataLayoutType layout) {
+void CxxConfig::set_mlu_input_layout(DataLayoutType layout) {
   mlu_input_layout_ = layout;
 }
-void CxxConfig::mlu_set_use_first_conv(bool use_first_conv) {
+void CxxConfig::set_mlu_use_first_conv(bool use_first_conv) {
   mlu_use_first_conv_ = use_first_conv;
 }
-void CxxConfig::mlu_set_first_conv_mean(const std::vector<float> &mean) {
+void CxxConfig::set_mlu_first_conv_mean(const std::vector<float> &mean) {
   mlu_first_conv_mean_ = mean;
 }
-void CxxConfig::mlu_set_first_conv_std(const std::vector<float> &std) {
+void CxxConfig::set_mlu_first_conv_std(const std::vector<float> &std) {
   mlu_first_conv_std_ = std;
 }
 lite_api::MLUCoreVersion CxxConfig::mlu_core_version() const {
diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h
index ba7fbc58b0..0cb60bf84f 100644
--- a/lite/api/paddle_api.h
+++ b/lite/api/paddle_api.h
@@ -136,7 +136,6 @@ class LITE_API CxxConfig : public ConfigBase {
 #ifdef LITE_WITH_X86
   int x86_math_library_math_threads_ = 1;
 #endif
-
   lite_api::MLUCoreVersion mlu_core_version_{lite_api::MLUCoreVersion::MLU_270};
   int mlu_core_number_{1};
   DataLayoutType mlu_input_layout_{DATALAYOUT(kNCHW)};
@@ -171,12 +170,12 @@
   }
 #endif
 
-  void mlu_set_mlu_core_version(lite_api::MLUCoreVersion core_version);
-  void mlu_set_mlu_core_number(int core_number);
-  void mlu_set_input_layout()(DataLayoutType layout);
-  void mlu_set_use_first_conv(bool use_first_conv);
-  void mlu_set_first_conv_mean(const std::vector<float>& mean);
-  void mlu_set_first_conv_std(const std::vector<float>& std);
+  void set_mlu_core_version(lite_api::MLUCoreVersion core_version);
+  void set_mlu_core_number(int core_number);
+  void set_mlu_input_layout(DataLayoutType layout);
+  void set_mlu_use_first_conv(bool use_first_conv);
+  void set_mlu_first_conv_mean(const std::vector<float>& mean);
+  void set_mlu_first_conv_std(const std::vector<float>& std);
 
   lite_api::MLUCoreVersion mlu_core_version() const;
   int mlu_core_number() const;
diff --git a/lite/api/python/pybind/pybind.cc b/lite/api/python/pybind/pybind.cc
index 942d7f8b54..5512e7bc43 100644
--- a/lite/api/python/pybind/pybind.cc
+++ b/lite/api/python/pybind/pybind.cc
@@ -128,11 +128,12 @@ void BindLiteCxxConfig(py::module *m) {
       .def("power_mode", &CxxConfig::power_mode);
 #endif
 #ifdef LITE_WITH_MLU
-  cxx_config.def("set_use_firstconv", &CxxConfig::set_use_firstconv)
-      .def("set_mean", &CxxConfig::set_mean)
-      .def("set_std", &CxxConfig::set_std)
-      .def("set_mlu_core_version", &CxxConfig::set_mlu_core_version)
-      .def("set_mlu_core_number", &CxxConfig::set_mlu_core_number);
+  cxx_config.def("set_mlu_core_version", &CxxConfig::set_mlu_core_version)
+      .def("set_mlu_core_number", &CxxConfig::set_mlu_core_number)
+      .def("set_mlu_input_layout", &CxxConfig::set_mlu_input_layout)
+      .def("set_mlu_use_first_conv", &CxxConfig::set_mlu_use_first_conv)
+      .def("set_mlu_first_conv_mean", &CxxConfig::set_mlu_first_conv_mean)
+      .def("set_mlu_first_conv_std", &CxxConfig::set_mlu_first_conv_std);
 #endif
 }

diff --git a/lite/core/device_info.cc b/lite/core/device_info.cc
index 915eb0affa..29ac96ed74 100644
--- a/lite/core/device_info.cc
+++ b/lite/core/device_info.cc
@@ -72,6 +72,7 @@ thread_local int DeviceInfo::mlu_core_number_{1};
 thread_local bool DeviceInfo::use_first_conv_{false};
 thread_local std::vector<float> DeviceInfo::mean_vec_;
 thread_local std::vector<float> DeviceInfo::std_vec_;
+thread_local DataLayoutType DeviceInfo::input_layout_{DATALAYOUT(kNCHW)};
 #endif
 
 #ifdef TARGET_IOS
@@ -1123,7 +1124,7 @@ const std::vector<float>& DeviceInfo::MeanVec() const { return mean_vec_; }
 
 const std::vector<float>& DeviceInfo::StdVec() const { return std_vec_; }
 
-const DataLayoutType InputLayout() const { return input_layout_; }
+DataLayoutType DeviceInfo::InputLayout() const { return input_layout_; }
 
 #endif  // LITE_WITH_MLU
 
diff --git a/lite/core/device_info.h b/lite/core/device_info.h
index be89092a43..4e7e4742c4 100644
--- a/lite/core/device_info.h
+++ b/lite/core/device_info.h
@@ -67,7 +67,7 @@ class DeviceInfo {
   bool UseFirstConv();
   const std::vector<float>& MeanVec() const;
   const std::vector<float>& StdVec() const;
-  const DataLayoutType InputLayout() const;
+  DataLayoutType InputLayout() const;
 #endif
   void SetCache(int l1size, int l2size, int l3size);
   void SetArch(ARMArch arch) { arch_ = arch; }
diff --git a/lite/core/mir/mlu_postprocess_pass.cc b/lite/core/mir/mlu_postprocess_pass.cc
index 2b3e558d80..58f53420cb 100644
--- a/lite/core/mir/mlu_postprocess_pass.cc
+++ b/lite/core/mir/mlu_postprocess_pass.cc
@@ -74,7 +74,9 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
       const Type* in_arg_ty = kernel->GetInputDeclType("Input");
       const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
       if (DataLayoutCompatible(*in_arg_ty, *cur_node->AsArg().type) &&
-          DataLayoutCompatible(*out_arg_ty, *cast_type)) {
+          DataLayoutCompatible(*out_arg_ty, *cast_type) &&
+          // for first conv
+          PrecisionCompatibleTo(*in_arg_ty, *cur_node->AsArg().type)) {
         is_found = true;
       }
     } else if (op_type == "io_copy") {
@@ -121,7 +123,7 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
   cast_arg->AsArg().type = cast_type;
   auto* var = inst_node->AsStmt().op()->scope()->Var(cast_arg_name);
   // for CastAfter manully set the tensor's type
-  var->GetMutable<::paddle::lite::Tensor>();
+  var->GetMutable<paddle::lite::Tensor>();
 
   // create the stmt node
   auto* cast_inst = graph->NewInstructNode();
@@ -281,7 +283,7 @@ void MLUPostprocessPass::GetSubgraphOpArgType(Node* inst_node,
 
   // get subgraph's valid precision
   const auto& places = graph->valid_places();
-  std::set<::paddle::lite_api::PrecisionType> prec_set;
+  std::set<paddle::lite_api::PrecisionType> prec_set;
   for (const auto& place : places) {
     if (place.target == TARGET(kMLU)) {
       prec_set.insert(place.precision);
@@ -474,13 +476,20 @@ bool MLUPostprocessPass::IsFirstConvNode(Node* arg_node) {
   return false;
 }
 
-void MLUPostprocessPass::GatherFirstConvNodes(SSAGraph* graph) {
+void MLUPostprocessPass::GatherAndModifyFirstConvNodes(SSAGraph* graph) {
   for (auto& node : graph->mutable_nodes()) {
     if (!node.IsStmt()) continue;
     if (node.AsStmt().op_type() == "feed") {
       for (auto& out : node.outlinks) {
         if (IsFirstConvNode(out)) {
           first_conv_nodes_.insert(out->AsArg().name);
+          // modify first conv nodes' type
+          const auto* old_type = out->AsArg().type;
+          out->AsArg().type =
+              LiteType::GetTensorTy(old_type->target(),
+                                    paddle::lite_api::PrecisionType::kInt8,
+                                    old_type->layout(),
+                                    old_type->device());
         }
       }
     }
@@ -504,7 +513,7 @@ void MLUPostprocessPass::ModifyLayout(SSAGraph* graph) {
           out->AsArg().type =
               LiteType::GetTensorTy(old_type->target(),
                                     old_type->precision(),
-                                    ::paddle::lite_api::DataLayoutType::kNHWC,
+                                    paddle::lite_api::DataLayoutType::kNHWC,
                                     old_type->device());
         }
       }
@@ -523,7 +532,7 @@ void MLUPostprocessPass::ModifyLayout(SSAGraph* graph) {
           inp->AsArg().type =
               LiteType::GetTensorTy(old_type->target(),
                                     old_type->precision(),
-                                    ::paddle::lite_api::DataLayoutType::kNHWC,
+                                    paddle::lite_api::DataLayoutType::kNHWC,
                                     old_type->device());
         }
       }
@@ -539,12 +548,12 @@ void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   // 1: feed->arg_in->subgraph->... 2: ...->subgraph->arg_out->fetch;
   // arg_in and arg_out are assumed to be NHWC which user should be aware of.
   // Thus here we change these args' layout to NHWC
-  if (lite::DeviceInfo::Global().InputLayout() == DATALAYOUT(kNHWC) {
+  if (lite::DeviceInfo::Global().InputLayout() == DATALAYOUT(kNHWC)) {
     ModifyLayout(graph.get());
   }
 
   if (lite::DeviceInfo::Global().UseFirstConv()) {
-    GatherFirstConvNodes(graph.get());
+    GatherAndModifyFirstConvNodes(graph.get());
   }
 
   // insert io_copy, layout and precision cast of subgraph's inputs and outputs
diff --git a/lite/core/mir/mlu_postprocess_pass.h b/lite/core/mir/mlu_postprocess_pass.h
index 34b449cca6..688dd06fb5 100644
--- a/lite/core/mir/mlu_postprocess_pass.h
+++ b/lite/core/mir/mlu_postprocess_pass.h
@@ -109,7 +109,7 @@ class MLUPostprocessPass : public ProgramPass {
 
   void RecreateOp(Node* inst_node, SSAGraph* graph);
 
-  void GatherFirstConvNodes(SSAGraph* graph);
+  void GatherAndModifyFirstConvNodes(SSAGraph* graph);
 
   bool IsFirstConvNode(Node* arg_node);
 
diff --git a/lite/kernels/mlu/bridges/utility.h b/lite/kernels/mlu/bridges/utility.h
index 2af8274e07..fa8fb1597c 100644
--- a/lite/kernels/mlu/bridges/utility.h
+++ b/lite/kernels/mlu/bridges/utility.h
@@ -84,7 +84,7 @@ struct FPTypeTraits {
 
 template <>
 struct FPTypeTraits<paddle::lite_api::PrecisionType::kFP16> {
-  typedef ::paddle::lite::fluid::float16 T;
+  typedef paddle::lite::fluid::float16 T;
 };
 
 }  // namespace mlu
diff --git a/lite/kernels/mlu/layout_compute.cc b/lite/kernels/mlu/layout_compute.cc
index 97d8202553..af0306d72d 100644
--- a/lite/kernels/mlu/layout_compute.cc
+++ b/lite/kernels/mlu/layout_compute.cc
@@ -89,3 +89,20 @@ REGISTER_LITE_KERNEL(
                                       PRECISION(kFloat),
                                       DATALAYOUT(kNHWC))})
     .Finalize();
+
+REGISTER_LITE_KERNEL(
+    layout,
+    kMLU,
+    kInt8,
+    kNHWC,
+    paddle::lite::kernels::mlu::LayoutNchwToNhwcCompute<PRECISION(kInt8)>,
+    def_layout_nchw2nhwc_fp32_int8)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kInt8),
+                                      DATALAYOUT(kNCHW))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kHost),
+                                       PRECISION(kInt8),
+                                       DATALAYOUT(kNHWC))})
+    .Finalize();
diff --git a/lite/kernels/mlu/layout_compute.h b/lite/kernels/mlu/layout_compute.h
index 5e87e35264..edacdf8a98 100644
--- a/lite/kernels/mlu/layout_compute.h
+++ b/lite/kernels/mlu/layout_compute.h
@@ -29,6 +29,24 @@ namespace lite {
 namespace kernels {
 namespace mlu {
 
+template <paddle::lite_api::PrecisionType>
+struct FPTypeTraits {};
+
+template <>
+struct FPTypeTraits<paddle::lite_api::PrecisionType::kFloat> {
+  typedef float T;
+};
+
+template <>
+struct FPTypeTraits<paddle::lite_api::PrecisionType::kFP16> {
+  typedef paddle::lite::fluid::float16 T;
+};
+
+template <>
+struct FPTypeTraits<paddle::lite_api::PrecisionType::kInt8> {
+  typedef int8_t T;
+};
+
 template <lite::TargetType Target, typename T>
 inline void LayoutTransCompute(const int dim,
                                const lite::Context<Target>& context,
@@ -63,7 +81,7 @@ class LayoutNchwToNhwcCompute
     auto& param = this->template Param<param_t>();
     auto* x = param.x;
     auto* out = param.y;
-    out->template mutable_data<float>();
+    out->template mutable_data<typename FPTypeTraits<Precision>::T>();
 
     auto x_dims = param.x->dims().size();
     auto& context = this->ctx_->template As<X86Context>();
@@ -88,7 +106,8 @@ class LayoutNchwToNhwcCompute
         CHECK(0) << "Unsupport dim in mlu layout nchw to nhwc";
     }
 
-    LayoutTransCompute<lite::TargetType::kX86, float>(
+    LayoutTransCompute<lite::TargetType::kX86,
+                       typename FPTypeTraits<Precision>::T>(
         x_dims, context, *x, out, axis);
 
     if (x_dims > 2) {
@@ -111,7 +130,7 @@ class LayoutNhwcToNchwCompute
     auto& param = this->template Param<param_t>();
     auto* x = param.x;
     auto* out = param.y;
-    out->template mutable_data<float>();
+    out->template mutable_data<typename FPTypeTraits<Precision>::T>();
 
     auto x_dims = param.x->dims().size();
     auto& context = this->ctx_->template As<X86Context>();
@@ -136,7 +155,8 @@ class LayoutNhwcToNchwCompute
        CHECK(0) << "Unsupport dim in mlu layout nhwc to nchw";
     }
 
-    LayoutTransCompute<lite::TargetType::kX86, float>(
+    LayoutTransCompute<lite::TargetType::kX86,
+                       typename FPTypeTraits<Precision>::T>(
         x_dims, context, *x, out, axis);
 
     if (x_dims > 2) {
-- 
GitLab
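
Usage sketch (not part of the patch): with this change applied, the MLU options on CxxConfig would typically be configured as below. Only the set_mlu_* setters and MLUCoreVersion::MLU_270 come from the patch itself; the include path, model directory, valid-place list, core count, and the first-conv mean/std values are illustrative assumptions.

#include <vector>

#include "lite/api/paddle_api.h"  // assumed include path for CxxConfig

int main() {
  using namespace paddle::lite_api;  // NOLINT

  CxxConfig config;
  config.set_model_dir("./mobilenet_v1_mlu");  // placeholder model directory
  // Illustrative places: an MLU build typically lists kMLU first,
  // followed by host fallbacks.
  config.set_valid_places({
      Place{TARGET(kMLU), PRECISION(kFloat), DATALAYOUT(kNHWC)},
      Place{TARGET(kHost), PRECISION(kFloat)},
  });

  // MLU options renamed/added by this patch.
  config.set_mlu_core_version(MLUCoreVersion::MLU_270);
  config.set_mlu_core_number(4);                   // assumed core count
  config.set_mlu_input_layout(DATALAYOUT(kNHWC));  // feed/fetch args stay NHWC
  config.set_mlu_use_first_conv(true);
  config.set_mlu_first_conv_mean({124.0f, 117.0f, 104.0f});  // placeholder values
  config.set_mlu_first_conv_std({59.0f, 57.0f, 57.0f});      // placeholder values

  auto predictor = CreatePaddlePredictor(config);
  return predictor != nullptr ? 0 : 1;
}

When set_mlu_input_layout(DATALAYOUT(kNHWC)) is used, the MLUPostprocessPass above calls ModifyLayout() so that the graph's feed/fetch argument nodes are treated as NHWC; with set_mlu_use_first_conv(true), the first conv's input is additionally retyped to int8 by GatherAndModifyFirstConvNodes().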