From 92179da1b9ad5e66ea539413108c413f8f468190 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9F=B3=E6=99=93=E4=BC=9F?= <39303645+Shixiaowei02@users.noreply.github.com> Date: Fri, 1 Nov 2019 10:06:51 +0800 Subject: [PATCH] refactor: BindTargets and ExcludeTargets, test=develop (#2321) --- lite/api/paddle_place.cc | 32 +++++++++++++++++ lite/api/paddle_place.h | 12 +++++++ lite/core/mir/fusion/conv_bn_fuse_pass.cc | 3 +- lite/core/mir/pass.h | 29 +++++++++++++-- lite/core/mir/pass_registry.h | 4 +++ lite/core/mir/pass_utils.cc | 44 +++-------------------- 6 files changed, 81 insertions(+), 43 deletions(-) diff --git a/lite/api/paddle_place.cc b/lite/api/paddle_place.cc index ccacb027d6..894d839185 100644 --- a/lite/api/paddle_place.cc +++ b/lite/api/paddle_place.cc @@ -122,5 +122,37 @@ const std::string& DataLayoutRepr(DataLayoutType layout) { return datalayout2string[x]; } +std::set ExpandValidTargets(TargetType target) { + static const std::set valid_set({TARGET(kX86), + TARGET(kCUDA), + TARGET(kARM), + TARGET(kOpenCL), + TARGET(kNPU), + TARGET(kXPU), + TARGET(kFPGA)}); + if (target == TARGET(kAny)) { + return valid_set; + } + return std::set({target}); +} + +std::set ExpandValidPrecisions(PrecisionType precision) { + static const std::set valid_set( + {PRECISION(kFloat), PRECISION(kInt8), PRECISION(kFP16), PRECISION(kAny)}); + if (precision == PRECISION(kAny)) { + return valid_set; + } + return std::set({precision}); +} + +std::set ExpandValidLayouts(DataLayoutType layout) { + static const std::set valid_set( + {DATALAYOUT(kNCHW), DATALAYOUT(kAny), DATALAYOUT(kNHWC)}); + if (layout == DATALAYOUT(kAny)) { + return valid_set; + } + return std::set({layout}); +} + } // namespace lite_api } // namespace paddle diff --git a/lite/api/paddle_place.h b/lite/api/paddle_place.h index 19ec5c6e8b..259887e2fb 100644 --- a/lite/api/paddle_place.h +++ b/lite/api/paddle_place.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include #include // Generic helper definitions for shared library support @@ -125,6 +126,17 @@ const std::string& PrecisionRepr(PrecisionType precision); const std::string& DataLayoutRepr(DataLayoutType layout); +// Get a set of all the elements represented by the target. +std::set ExpandValidTargets(TargetType target = TARGET(kAny)); + +// Get a set of all the elements represented by the precision. +std::set ExpandValidPrecisions( + PrecisionType precision = PRECISION(kAny)); + +// Get a set of all the elements represented by the layout. +std::set ExpandValidLayouts( + DataLayoutType layout = DATALAYOUT(kAny)); + /* * Place specifies the execution context of a Kernel or input/output for a * kernel. It is used to make the analysis of the MIR more clear and accurate. diff --git a/lite/core/mir/fusion/conv_bn_fuse_pass.cc b/lite/core/mir/fusion/conv_bn_fuse_pass.cc index 25c8a21725..d9d9c1bbf5 100644 --- a/lite/core/mir/fusion/conv_bn_fuse_pass.cc +++ b/lite/core/mir/fusion/conv_bn_fuse_pass.cc @@ -44,4 +44,5 @@ void ConvBNFusePass::Apply(const std::unique_ptr& graph) { } // namespace paddle REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass) - .BindTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}) + .ExcludeTargets({TARGET(kX86)}); diff --git a/lite/core/mir/pass.h b/lite/core/mir/pass.h index 8fd12fafa3..4de0fdbf35 100644 --- a/lite/core/mir/pass.h +++ b/lite/core/mir/pass.h @@ -49,10 +49,35 @@ class Pass { // Some passes only apply to qualified targets, which need to be explicitly // declared. - // Bind the target. At runtime, there must be one device in the bound targets. + + // Bind targets. At runtime, there must be one device in the bound targets. void BindTargets(const std::set& targets) { - bound_targets_ = targets; + std::set res; + for (const auto& target : targets) { + const std::set& universe = ExpandValidTargets(target); + std::set_union(bound_targets_.begin(), + bound_targets_.end(), + universe.begin(), + universe.end(), + std::inserter(res, res.begin())); + } + bound_targets_ = res; + } + + // Exclude targets. At runtime, there must be one device in the bound targets. + void ExcludeTargets(const std::set& targets) { + std::set res; + for (const auto& target : targets) { + const std::set& universe = ExpandValidTargets(target); + std::set_difference(bound_targets_.begin(), + bound_targets_.end(), + universe.begin(), + universe.end(), + std::inserter(res, res.begin())); + } + bound_targets_ = res; } + // Get all bound targets. const std::set& Targets() const { return bound_targets_; } diff --git a/lite/core/mir/pass_registry.h b/lite/core/mir/pass_registry.h index d8418f5c40..849f80aea2 100644 --- a/lite/core/mir/pass_registry.h +++ b/lite/core/mir/pass_registry.h @@ -34,6 +34,10 @@ class PassRegistry { pass_->BindTargets(targets); return *this; } + PassRegistry& ExcludeTargets(const std::set& targets) { + pass_->ExcludeTargets(targets); + return *this; + } PassRegistry& BindKernel(const std::string& name, const lite_api::Place& place) { pass_->BindKernel(name, place); diff --git a/lite/core/mir/pass_utils.cc b/lite/core/mir/pass_utils.cc index cfa43f8d6e..4f6be2c186 100644 --- a/lite/core/mir/pass_utils.cc +++ b/lite/core/mir/pass_utils.cc @@ -23,53 +23,17 @@ namespace lite { using lite_api::Place; -namespace { - -template -class Types final { - public: - explicit Types(const std::set& types) : types_(types) {} - ~Types() = default; - std::set ValidSet(const T& element) const; - - private: - const std::set types_; -}; - -template -std::set Types::ValidSet(const T& element) const { - if (element == T::kAny) { - return types_; - } else if (element == T::kUnk) { - LOG(FATAL) << "The type of the kernel's place is unknown."; - } - return std::set({element}); -} - void ExpandPlaces(std::set* places, const Place& place) { - static const Types target_set({TARGET(kHost), - TARGET(kX86), - TARGET(kCUDA), - TARGET(kARM), - TARGET(kOpenCL), - TARGET(kNPU), - TARGET(kXPU), - TARGET(kFPGA)}); - static const Types precision_set( - {PRECISION(kFloat), PRECISION(kInt8), PRECISION(kFP16), PRECISION(kAny)}); - static const Types layout_set( - {DATALAYOUT(kNCHW), DATALAYOUT(kAny), DATALAYOUT(kNHWC)}); - for (const auto& target : target_set.ValidSet(place.target)) { - for (const auto& precision : precision_set.ValidSet(place.precision)) { - for (const auto& layout : layout_set.ValidSet(place.layout)) { + for (const auto& target : lite_api::ExpandValidTargets(place.target)) { + for (const auto& precision : + lite_api::ExpandValidPrecisions(place.precision)) { + for (const auto& layout : lite_api::ExpandValidLayouts(place.layout)) { places->insert(Place(target, precision, layout)); } } } } -} // anonymous namespace - bool KernelRegistered(const std::string name, const Place& place) { std::set places; ExpandPlaces(&places, place); -- GitLab