From dbc8f893d76cb37e6028770201706a88293a36f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9F=B3=E6=99=93=E4=BC=9F?= <39303645+Shixiaowei02@users.noreply.github.com> Date: Wed, 18 Sep 2019 20:47:49 +0800 Subject: [PATCH] modify the device binding logic of the pass, test=develop (#2060) --- .../mir/fusion/conv_activation_fuse_pass.cc | 3 +- lite/core/mir/fusion/conv_bn_fuse_pass.cc | 3 +- .../elementwise_add_activation_fuse_pass.cc | 3 +- lite/core/mir/fusion/fc_fuse_pass.cc | 3 +- .../mir/fusion/shuffle_channel_fuse_pass.cc | 3 +- lite/core/mir/io_copy_kernel_pick_pass.cc | 3 +- lite/core/mir/pass_registry.h | 5 ++ lite/core/mir/pass_utils.cc | 67 ++++++++++++++++++- lite/core/mir/pass_utils.h | 4 ++ .../mir/subgraph/generate_npu_program_pass.cc | 2 +- lite/core/mir/type_layout_cast_pass.cc | 4 +- lite/core/mir/type_precision_cast_pass.cc | 4 +- lite/core/mir/type_target_cast_pass.cc | 4 +- lite/core/op_registry.h | 10 ++- lite/core/optimizer.h | 4 +- 15 files changed, 104 insertions(+), 18 deletions(-) diff --git a/lite/core/mir/fusion/conv_activation_fuse_pass.cc b/lite/core/mir/fusion/conv_activation_fuse_pass.cc index ceb3b0ea34..7ced84e8f5 100644 --- a/lite/core/mir/fusion/conv_activation_fuse_pass.cc +++ b/lite/core/mir/fusion/conv_activation_fuse_pass.cc @@ -39,4 +39,5 @@ void ConvActivationFusePass::Apply(const std::unique_ptr& graph) { REGISTER_MIR_PASS(lite_conv_activation_fuse_pass, paddle::lite::mir::ConvActivationFusePass) - .BindTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}) + .BindKernel("conv2d"); diff --git a/lite/core/mir/fusion/conv_bn_fuse_pass.cc b/lite/core/mir/fusion/conv_bn_fuse_pass.cc index 8ac2dd252e..d7e274b146 100644 --- a/lite/core/mir/fusion/conv_bn_fuse_pass.cc +++ b/lite/core/mir/fusion/conv_bn_fuse_pass.cc @@ -35,4 +35,5 @@ void ConvBNFusePass::Apply(const std::unique_ptr& graph) { } // namespace paddle REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass) - .BindTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}) + .BindKernel("elementwise_add"); diff --git a/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc index 67e9e56fcf..af66f5ab66 100644 --- a/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc +++ b/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc @@ -34,4 +34,5 @@ void ElementwiseAddActivationFusePass::Apply( REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass, paddle::lite::mir::ElementwiseAddActivationFusePass) - .BindTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}) + .BindKernel("fusion_elementwise_add_activation"); diff --git a/lite/core/mir/fusion/fc_fuse_pass.cc b/lite/core/mir/fusion/fc_fuse_pass.cc index 380f8f932d..ed10f06f56 100644 --- a/lite/core/mir/fusion/fc_fuse_pass.cc +++ b/lite/core/mir/fusion/fc_fuse_pass.cc @@ -32,4 +32,5 @@ void FcFusePass::Apply(const std::unique_ptr& graph) { } // namespace paddle REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass) - .BindTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}) + .BindKernel("fc"); diff --git a/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc b/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc index 01b18a1842..2c289da82c 100644 --- a/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc +++ b/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc @@ -36,4 +36,5 @@ void ShuffleChannelFusePass::Apply(const std::unique_ptr& graph) { REGISTER_MIR_PASS(lite_shuffle_channel_fuse_pass, paddle::lite::mir::ShuffleChannelFusePass) - .BindTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}) + .BindKernel("shuffle_channel"); diff --git a/lite/core/mir/io_copy_kernel_pick_pass.cc b/lite/core/mir/io_copy_kernel_pick_pass.cc index 90cf3559e3..df5ddffe8a 100644 --- a/lite/core/mir/io_copy_kernel_pick_pass.cc +++ b/lite/core/mir/io_copy_kernel_pick_pass.cc @@ -72,4 +72,5 @@ class IoCopyKernelPickPass : public StmtPass { REGISTER_MIR_PASS(io_copy_kernel_pick_pass, paddle::lite::mir::IoCopyKernelPickPass) - .BindTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}) + .BindKernel("io_copy"); diff --git a/lite/core/mir/pass_registry.h b/lite/core/mir/pass_registry.h index 89a4b3efd6..d8418f5c40 100644 --- a/lite/core/mir/pass_registry.h +++ b/lite/core/mir/pass_registry.h @@ -39,6 +39,11 @@ class PassRegistry { pass_->BindKernel(name, place); return *this; } + PassRegistry& BindKernel(const std::string& name) { + pass_->BindKernel(name, + Place(TARGET(kAny), PRECISION(kAny), DATALAYOUT(kAny))); + return *this; + } bool Touch() const { return true; } private: diff --git a/lite/core/mir/pass_utils.cc b/lite/core/mir/pass_utils.cc index f15a7d713c..b67f5e4bd1 100644 --- a/lite/core/mir/pass_utils.cc +++ b/lite/core/mir/pass_utils.cc @@ -16,10 +16,72 @@ #include #include #include +#include "lite/core/op_registry.h" namespace paddle { namespace lite { +using lite_api::Place; + +namespace { + +template +class Types final { + public: + explicit Types(const std::set& types) : types_(types) {} + ~Types() = default; + std::set ValidSet(const T& element) const; + + private: + const std::set types_; +}; + +template +std::set Types::ValidSet(const T& element) const { + if (element == T::kAny) { + return types_; + } else if (element == T::kUnk) { + LOG(FATAL) << "The type of the kernel's place is unknown."; + } + return std::set({element}); +} + +bool ExpandPlaces(std::set* places, const Place& place) { + static const Types target_set({TARGET(kHost), + TARGET(kX86), + TARGET(kCUDA), + TARGET(kARM), + TARGET(kOpenCL), + TARGET(kNPU), + TARGET(kFPGA)}); + static const Types precision_set( + {PRECISION(kFloat), PRECISION(kInt8), PRECISION(kFP16), PRECISION(kAny)}); + static const Types layout_set( + {DATALAYOUT(kNCHW), DATALAYOUT(kAny), DATALAYOUT(kNHWC)}); + for (const auto& target : target_set.ValidSet(place.target)) { + for (const auto& precision : precision_set.ValidSet(place.precision)) { + for (const auto& layout : layout_set.ValidSet(place.layout)) { + places->insert(Place(target, precision, layout)); + } + } + } +} + +} // anonymous namespace + +bool KernelRegistered(const std::string name, const Place& place) { + std::set places; + ExpandPlaces(&places, place); + for (const auto& p : places) { + if (!KernelRegistry::Global() + .Create(name, p.target, p.precision, p.layout) + .empty()) { + return true; + } + } + return false; +} + bool PassMatchesTarget(const mir::Pass& pass, TargetType target) { const auto& targets = pass.Targets(); if (targets.find(TARGET(kAny)) != targets.end()) return true; @@ -30,10 +92,9 @@ bool PassMatchesKernels(const mir::Pass& pass) { const auto& kernels = pass.GetBoundKernels(); for (const auto& kernel : kernels) { for (const auto& place : kernel.second) { - if (KernelRegistry::Global() - .Create(kernel.first, place.target, place.precision, place.layout) - .empty()) + if (!KernelRegistered(kernel.first, place)) { return false; + } } } return true; diff --git a/lite/core/mir/pass_utils.h b/lite/core/mir/pass_utils.h index 445c91fe77..942f64bf31 100644 --- a/lite/core/mir/pass_utils.h +++ b/lite/core/mir/pass_utils.h @@ -14,11 +14,15 @@ #pragma once +#include #include "lite/core/mir/pass.h" namespace paddle { namespace lite { +// Query if the specified kernel has been registered. +bool KernelRegistered(const std::string name, const Place& place); + // Check if the pass hits the hardware target. bool PassMatchesTarget(const mir::Pass& pass, TargetType target); diff --git a/lite/core/mir/subgraph/generate_npu_program_pass.cc b/lite/core/mir/subgraph/generate_npu_program_pass.cc index 8badd357c3..27e7997372 100644 --- a/lite/core/mir/subgraph/generate_npu_program_pass.cc +++ b/lite/core/mir/subgraph/generate_npu_program_pass.cc @@ -215,4 +215,4 @@ std::unique_ptr GenerateNPUProgramPass::GenProgram() { REGISTER_MIR_PASS(generate_npu_program_pass, paddle::lite::mir::subgraph::GenerateNPUProgramPass) - .BindTargets({TARGET(kAny)}); + .BindTargets({TARGET(kNPU)}); diff --git a/lite/core/mir/type_layout_cast_pass.cc b/lite/core/mir/type_layout_cast_pass.cc index 11f4a21f24..57523a0274 100644 --- a/lite/core/mir/type_layout_cast_pass.cc +++ b/lite/core/mir/type_layout_cast_pass.cc @@ -174,4 +174,6 @@ void TypeLayoutTransformPass::SetValidPlaces( REGISTER_MIR_PASS(type_layout_cast_pass, paddle::lite::mir::TypeLayoutTransformPass) - .BindTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}) + .BindKernel("layout_once") + .BindKernel("layout"); diff --git a/lite/core/mir/type_precision_cast_pass.cc b/lite/core/mir/type_precision_cast_pass.cc index 5a99a67255..5a61f5fc66 100644 --- a/lite/core/mir/type_precision_cast_pass.cc +++ b/lite/core/mir/type_precision_cast_pass.cc @@ -180,4 +180,6 @@ void PrecisionCastPass::SetValidPlaces(const std::vector& valid_places) { REGISTER_MIR_PASS(type_precision_cast_pass, paddle::lite::mir::PrecisionCastPass) - .BindTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}) + .BindKernel("calib_once") + .BindKernel("calib"); diff --git a/lite/core/mir/type_target_cast_pass.cc b/lite/core/mir/type_target_cast_pass.cc index 0af7fa3cfd..0141a488e4 100644 --- a/lite/core/mir/type_target_cast_pass.cc +++ b/lite/core/mir/type_target_cast_pass.cc @@ -180,4 +180,6 @@ void TypeTargetTransformPass::SetValidPlaces( REGISTER_MIR_PASS(type_target_cast_pass, paddle::lite::mir::TypeTargetTransformPass) - .BindTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}) + .BindKernel("io_copy_once") + .BindKernel("io_copy"); diff --git a/lite/core/op_registry.h b/lite/core/op_registry.h index 5b48c251c8..d2c0637b0b 100644 --- a/lite/core/op_registry.h +++ b/lite/core/op_registry.h @@ -174,9 +174,13 @@ class KernelRegistry final { std::list> Create(const std::string &op_type) { using kernel_registor_t = KernelRegistryForTarget; - return registries_[GetKernelOffset()] - .template get() - ->Creates(op_type); + std::list> kernel_list; + if (registries_[GetKernelOffset()].valid()) { + kernel_list = registries_[GetKernelOffset()] + .template get() + ->Creates(op_type); + } + return kernel_list; } std::list> Create(const std::string &op_type, diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h index 7361eed236..5b6a32447c 100644 --- a/lite/core/optimizer.h +++ b/lite/core/optimizer.h @@ -193,9 +193,9 @@ class Optimizer { matched = true; } } - matched = matched || PassMatchesKernels(*pass); + matched = matched && PassMatchesKernels(*pass); if (!matched) { - LOG(INFO) << "Skip " << x << " pass because the target does not match."; + LOG(INFO) << " - Skip " << x << " because the target does not match."; } else { pass->Apply(graph_); LOG(INFO) << "== Finished running: " << x; -- GitLab