diff --git a/lite/api/paddle_place.cc b/lite/api/paddle_place.cc index dbdf9ff269b372cd3dcd59769b15526b7631a5e5..0cf9885e8839b44267ec43a4b0a465bc1a4834dc 100644 --- a/lite/api/paddle_place.cc +++ b/lite/api/paddle_place.cc @@ -113,5 +113,36 @@ const std::string& DataLayoutRepr(DataLayoutType layout) { return datalayout2string[x]; } +std::set ExpandValidTargets(TargetType target) { + static const std::set valid_set({TARGET(kX86), + TARGET(kCUDA), + TARGET(kARM), + TARGET(kOpenCL), + TARGET(kNPU), + TARGET(kFPGA)}); + if (target == TARGET(kAny)) { + return valid_set; + } + return std::set({target}); +} + +std::set ExpandValidPrecisions(PrecisionType precision) { + static const std::set valid_set( + {PRECISION(kFloat), PRECISION(kInt8), PRECISION(kFP16), PRECISION(kAny)}); + if (precision == PRECISION(kAny)) { + return valid_set; + } + return std::set({precision}); +} + +std::set ExpandValidLayouts(DataLayoutType layout) { + static const std::set valid_set( + {DATALAYOUT(kNCHW), DATALAYOUT(kAny), DATALAYOUT(kNHWC)}); + if (layout == DATALAYOUT(kAny)) { + return valid_set; + } + return std::set({layout}); +} + } // namespace lite_api } // namespace paddle diff --git a/lite/api/paddle_place.h b/lite/api/paddle_place.h index 5e4f2ed21c8298ac15a912672e3d15633d0a3ecb..b9842cd355e7f273d3f7bd0470ff507b847825ad 100644 --- a/lite/api/paddle_place.h +++ b/lite/api/paddle_place.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include #include // Generic helper definitions for shared library support @@ -124,6 +125,17 @@ const std::string& PrecisionRepr(PrecisionType precision); const std::string& DataLayoutRepr(DataLayoutType layout); +// Get a set of all the elements represented by the target. +std::set ExpandValidTargets(TargetType target = TARGET(kAny)); + +// Get a set of all the elements represented by the precision. +std::set ExpandValidPrecisions( + PrecisionType precision = PRECISION(kAny)); + +// Get a set of all the elements represented by the layout. +std::set ExpandValidLayouts( + DataLayoutType layout = DATALAYOUT(kAny)); + /* * Place specifies the execution context of a Kernel or input/output for a * kernel. It is used to make the analysis of the MIR more clear and accurate. diff --git a/lite/core/mir/fusion/conv_bn_fuse_pass.cc b/lite/core/mir/fusion/conv_bn_fuse_pass.cc index 25c8a217251d25f0f7b4a37c4c656c535810b76e..d9d9c1bbf55bd33c31aa9a22de934d4eae8657c6 100644 --- a/lite/core/mir/fusion/conv_bn_fuse_pass.cc +++ b/lite/core/mir/fusion/conv_bn_fuse_pass.cc @@ -44,4 +44,5 @@ void ConvBNFusePass::Apply(const std::unique_ptr& graph) { } // namespace paddle REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass) - .BindTargets({TARGET(kAny)}); + .BindTargets({TARGET(kAny)}) + .ExcludeTargets({TARGET(kX86)}); diff --git a/lite/core/mir/pass.h b/lite/core/mir/pass.h index 8fd12fafa3fd6183eb3bba894be04d96075f1bc3..4de0fdbf357160348a403d3c8527fe62891237f0 100644 --- a/lite/core/mir/pass.h +++ b/lite/core/mir/pass.h @@ -49,10 +49,35 @@ class Pass { // Some passes only apply to qualified targets, which need to be explicitly // declared. - // Bind the target. At runtime, there must be one device in the bound targets. + + // Bind targets. At runtime, there must be one device in the bound targets. void BindTargets(const std::set& targets) { - bound_targets_ = targets; + std::set res; + for (const auto& target : targets) { + const std::set& universe = ExpandValidTargets(target); + std::set_union(bound_targets_.begin(), + bound_targets_.end(), + universe.begin(), + universe.end(), + std::inserter(res, res.begin())); + } + bound_targets_ = res; + } + + // Exclude targets. At runtime, there must be one device in the bound targets. + void ExcludeTargets(const std::set& targets) { + std::set res; + for (const auto& target : targets) { + const std::set& universe = ExpandValidTargets(target); + std::set_difference(bound_targets_.begin(), + bound_targets_.end(), + universe.begin(), + universe.end(), + std::inserter(res, res.begin())); + } + bound_targets_ = res; } + // Get all bound targets. const std::set& Targets() const { return bound_targets_; } diff --git a/lite/core/mir/pass_registry.h b/lite/core/mir/pass_registry.h index d8418f5c400921430b2f42d92173a6e02a95eb75..849f80aea2191b72ac423c7125a4e69cb6927be5 100644 --- a/lite/core/mir/pass_registry.h +++ b/lite/core/mir/pass_registry.h @@ -34,6 +34,10 @@ class PassRegistry { pass_->BindTargets(targets); return *this; } + PassRegistry& ExcludeTargets(const std::set& targets) { + pass_->ExcludeTargets(targets); + return *this; + } PassRegistry& BindKernel(const std::string& name, const lite_api::Place& place) { pass_->BindKernel(name, place); diff --git a/lite/core/mir/pass_utils.cc b/lite/core/mir/pass_utils.cc index 804d4e1b5bc94f0e7804fa588e107a298210143b..4f6be2c186d2d940a799201812cce397a9e94eb4 100644 --- a/lite/core/mir/pass_utils.cc +++ b/lite/core/mir/pass_utils.cc @@ -23,52 +23,17 @@ namespace lite { using lite_api::Place; -namespace { - -template -class Types final { - public: - explicit Types(const std::set& types) : types_(types) {} - ~Types() = default; - std::set ValidSet(const T& element) const; - - private: - const std::set types_; -}; - -template -std::set Types::ValidSet(const T& element) const { - if (element == T::kAny) { - return types_; - } else if (element == T::kUnk) { - LOG(FATAL) << "The type of the kernel's place is unknown."; - } - return std::set({element}); -} - void ExpandPlaces(std::set* places, const Place& place) { - static const Types target_set({TARGET(kHost), - TARGET(kX86), - TARGET(kCUDA), - TARGET(kARM), - TARGET(kOpenCL), - TARGET(kNPU), - TARGET(kFPGA)}); - static const Types precision_set( - {PRECISION(kFloat), PRECISION(kInt8), PRECISION(kFP16), PRECISION(kAny)}); - static const Types layout_set( - {DATALAYOUT(kNCHW), DATALAYOUT(kAny), DATALAYOUT(kNHWC)}); - for (const auto& target : target_set.ValidSet(place.target)) { - for (const auto& precision : precision_set.ValidSet(place.precision)) { - for (const auto& layout : layout_set.ValidSet(place.layout)) { + for (const auto& target : lite_api::ExpandValidTargets(place.target)) { + for (const auto& precision : + lite_api::ExpandValidPrecisions(place.precision)) { + for (const auto& layout : lite_api::ExpandValidLayouts(place.layout)) { places->insert(Place(target, precision, layout)); } } } } -} // anonymous namespace - bool KernelRegistered(const std::string name, const Place& place) { std::set places; ExpandPlaces(&places, place);