diff --git a/lite/api/paddle_place.cc b/lite/api/paddle_place.cc index c11a72ec6cbe1e427ad71bcceec3b141158479c6..dbdf9ff269b372cd3dcd59769b15526b7631a5e5 100644 --- a/lite/api/paddle_place.cc +++ b/lite/api/paddle_place.cc @@ -113,37 +113,5 @@ const std::string& DataLayoutRepr(DataLayoutType layout) { return datalayout2string[x]; } -std::set ExpandValidTargets(TargetType target) { - static const std::set valid_set({TARGET(kX86), - TARGET(kCUDA), - TARGET(kARM), - TARGET(kOpenCL), - TARGET(kNPU), - TARGET(kXPU), - TARGET(kFPGA)}); - if (target == TARGET(kAny)) { - return valid_set; - } - return std::set({target}); -} - -std::set ExpandValidPrecisions(PrecisionType precision) { - static const std::set valid_set( - {PRECISION(kFloat), PRECISION(kInt8), PRECISION(kFP16), PRECISION(kAny)}); - if (precision == PRECISION(kAny)) { - return valid_set; - } - return std::set({precision}); -} - -std::set ExpandValidLayouts(DataLayoutType layout) { - static const std::set valid_set( - {DATALAYOUT(kNCHW), DATALAYOUT(kAny), DATALAYOUT(kNHWC)}); - if (layout == DATALAYOUT(kAny)) { - return valid_set; - } - return std::set({layout}); -} - } // namespace lite_api } // namespace paddle diff --git a/lite/api/paddle_place.h b/lite/api/paddle_place.h index b9842cd355e7f273d3f7bd0470ff507b847825ad..5e4f2ed21c8298ac15a912672e3d15633d0a3ecb 100644 --- a/lite/api/paddle_place.h +++ b/lite/api/paddle_place.h @@ -13,7 +13,6 @@ // limitations under the License. #pragma once -#include #include // Generic helper definitions for shared library support @@ -125,17 +124,6 @@ const std::string& PrecisionRepr(PrecisionType precision); const std::string& DataLayoutRepr(DataLayoutType layout); -// Get a set of all the elements represented by the target. -std::set ExpandValidTargets(TargetType target = TARGET(kAny)); - -// Get a set of all the elements represented by the precision. -std::set ExpandValidPrecisions( - PrecisionType precision = PRECISION(kAny)); - -// Get a set of all the elements represented by the layout. -std::set ExpandValidLayouts( - DataLayoutType layout = DATALAYOUT(kAny)); - /* * Place specifies the execution context of a Kernel or input/output for a * kernel. It is used to make the analysis of the MIR more clear and accurate. diff --git a/lite/core/mir/fusion/conv_bn_fuse_pass.cc b/lite/core/mir/fusion/conv_bn_fuse_pass.cc index d9d9c1bbf55bd33c31aa9a22de934d4eae8657c6..25c8a217251d25f0f7b4a37c4c656c535810b76e 100644 --- a/lite/core/mir/fusion/conv_bn_fuse_pass.cc +++ b/lite/core/mir/fusion/conv_bn_fuse_pass.cc @@ -44,5 +44,4 @@ void ConvBNFusePass::Apply(const std::unique_ptr& graph) { } // namespace paddle REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass) - .BindTargets({TARGET(kAny)}) - .ExcludeTargets({TARGET(kX86)}); + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/pass.h b/lite/core/mir/pass.h index 4de0fdbf357160348a403d3c8527fe62891237f0..8fd12fafa3fd6183eb3bba894be04d96075f1bc3 100644 --- a/lite/core/mir/pass.h +++ b/lite/core/mir/pass.h @@ -49,35 +49,10 @@ class Pass { // Some passes only apply to qualified targets, which need to be explicitly // declared. - - // Bind targets. At runtime, there must be one device in the bound targets. + // Bind the target. At runtime, there must be one device in the bound targets. void BindTargets(const std::set& targets) { - std::set res; - for (const auto& target : targets) { - const std::set& universe = ExpandValidTargets(target); - std::set_union(bound_targets_.begin(), - bound_targets_.end(), - universe.begin(), - universe.end(), - std::inserter(res, res.begin())); - } - bound_targets_ = res; - } - - // Exclude targets. At runtime, there must be one device in the bound targets. - void ExcludeTargets(const std::set& targets) { - std::set res; - for (const auto& target : targets) { - const std::set& universe = ExpandValidTargets(target); - std::set_difference(bound_targets_.begin(), - bound_targets_.end(), - universe.begin(), - universe.end(), - std::inserter(res, res.begin())); - } - bound_targets_ = res; + bound_targets_ = targets; } - // Get all bound targets. const std::set& Targets() const { return bound_targets_; } diff --git a/lite/core/mir/pass_registry.h b/lite/core/mir/pass_registry.h index 849f80aea2191b72ac423c7125a4e69cb6927be5..d8418f5c400921430b2f42d92173a6e02a95eb75 100644 --- a/lite/core/mir/pass_registry.h +++ b/lite/core/mir/pass_registry.h @@ -34,10 +34,6 @@ class PassRegistry { pass_->BindTargets(targets); return *this; } - PassRegistry& ExcludeTargets(const std::set& targets) { - pass_->ExcludeTargets(targets); - return *this; - } PassRegistry& BindKernel(const std::string& name, const lite_api::Place& place) { pass_->BindKernel(name, place); diff --git a/lite/core/mir/pass_utils.cc b/lite/core/mir/pass_utils.cc index 4f6be2c186d2d940a799201812cce397a9e94eb4..804d4e1b5bc94f0e7804fa588e107a298210143b 100644 --- a/lite/core/mir/pass_utils.cc +++ b/lite/core/mir/pass_utils.cc @@ -23,17 +23,52 @@ namespace lite { using lite_api::Place; +namespace { + +template +class Types final { + public: + explicit Types(const std::set& types) : types_(types) {} + ~Types() = default; + std::set ValidSet(const T& element) const; + + private: + const std::set types_; +}; + +template +std::set Types::ValidSet(const T& element) const { + if (element == T::kAny) { + return types_; + } else if (element == T::kUnk) { + LOG(FATAL) << "The type of the kernel's place is unknown."; + } + return std::set({element}); +} + void ExpandPlaces(std::set* places, const Place& place) { - for (const auto& target : lite_api::ExpandValidTargets(place.target)) { - for (const auto& precision : - lite_api::ExpandValidPrecisions(place.precision)) { - for (const auto& layout : lite_api::ExpandValidLayouts(place.layout)) { + static const Types target_set({TARGET(kHost), + TARGET(kX86), + TARGET(kCUDA), + TARGET(kARM), + TARGET(kOpenCL), + TARGET(kNPU), + TARGET(kFPGA)}); + static const Types precision_set( + {PRECISION(kFloat), PRECISION(kInt8), PRECISION(kFP16), PRECISION(kAny)}); + static const Types layout_set( + {DATALAYOUT(kNCHW), DATALAYOUT(kAny), DATALAYOUT(kNHWC)}); + for (const auto& target : target_set.ValidSet(place.target)) { + for (const auto& precision : precision_set.ValidSet(place.precision)) { + for (const auto& layout : layout_set.ValidSet(place.layout)) { places->insert(Place(target, precision, layout)); } } } } +} // anonymous namespace + bool KernelRegistered(const std::string name, const Place& place) { std::set places; ExpandPlaces(&places, place);