未验证 提交 74878eaf 编写于 作者: L liu zhengxi 提交者: GitHub

[cherry-pick] refactor: BindTargets and ExcludeTargets #2329 (#2332)

* refactor: BindTargets and ExcludeTargets, test=release/v2.0.0

* delete xpu TARGET, test=release/v2.0.0
上级 d886718b
...@@ -113,5 +113,36 @@ const std::string& DataLayoutRepr(DataLayoutType layout) { ...@@ -113,5 +113,36 @@ const std::string& DataLayoutRepr(DataLayoutType layout) {
return datalayout2string[x]; return datalayout2string[x];
} }
std::set<TargetType> ExpandValidTargets(TargetType target) {
static const std::set<TargetType> valid_set({TARGET(kX86),
TARGET(kCUDA),
TARGET(kARM),
TARGET(kOpenCL),
TARGET(kNPU),
TARGET(kFPGA)});
if (target == TARGET(kAny)) {
return valid_set;
}
return std::set<TargetType>({target});
}
std::set<PrecisionType> ExpandValidPrecisions(PrecisionType precision) {
static const std::set<PrecisionType> valid_set(
{PRECISION(kFloat), PRECISION(kInt8), PRECISION(kFP16), PRECISION(kAny)});
if (precision == PRECISION(kAny)) {
return valid_set;
}
return std::set<PrecisionType>({precision});
}
std::set<DataLayoutType> ExpandValidLayouts(DataLayoutType layout) {
static const std::set<DataLayoutType> valid_set(
{DATALAYOUT(kNCHW), DATALAYOUT(kAny), DATALAYOUT(kNHWC)});
if (layout == DATALAYOUT(kAny)) {
return valid_set;
}
return std::set<DataLayoutType>({layout});
}
} // namespace lite_api } // namespace lite_api
} // namespace paddle } // namespace paddle
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <set>
#include <string> #include <string>
// Generic helper definitions for shared library support // Generic helper definitions for shared library support
...@@ -124,6 +125,17 @@ const std::string& PrecisionRepr(PrecisionType precision); ...@@ -124,6 +125,17 @@ const std::string& PrecisionRepr(PrecisionType precision);
const std::string& DataLayoutRepr(DataLayoutType layout); const std::string& DataLayoutRepr(DataLayoutType layout);
// Get a set of all the elements represented by the target.
std::set<TargetType> ExpandValidTargets(TargetType target = TARGET(kAny));
// Get a set of all the elements represented by the precision.
std::set<PrecisionType> ExpandValidPrecisions(
PrecisionType precision = PRECISION(kAny));
// Get a set of all the elements represented by the layout.
std::set<DataLayoutType> ExpandValidLayouts(
DataLayoutType layout = DATALAYOUT(kAny));
/* /*
* Place specifies the execution context of a Kernel or input/output for a * Place specifies the execution context of a Kernel or input/output for a
* kernel. It is used to make the analysis of the MIR more clear and accurate. * kernel. It is used to make the analysis of the MIR more clear and accurate.
......
...@@ -44,4 +44,5 @@ void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -44,4 +44,5 @@ void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass) REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass)
.BindTargets({TARGET(kAny)}); .BindTargets({TARGET(kAny)})
.ExcludeTargets({TARGET(kX86)});
...@@ -49,10 +49,35 @@ class Pass { ...@@ -49,10 +49,35 @@ class Pass {
// Some passes only apply to qualified targets, which need to be explicitly // Some passes only apply to qualified targets, which need to be explicitly
// declared. // declared.
// Bind the target. At runtime, there must be one device in the bound targets.
// Bind targets. At runtime, there must be one device in the bound targets.
void BindTargets(const std::set<TargetType>& targets) { void BindTargets(const std::set<TargetType>& targets) {
bound_targets_ = targets; std::set<TargetType> res;
for (const auto& target : targets) {
const std::set<TargetType>& universe = ExpandValidTargets(target);
std::set_union(bound_targets_.begin(),
bound_targets_.end(),
universe.begin(),
universe.end(),
std::inserter(res, res.begin()));
}
bound_targets_ = res;
}
// Exclude targets. At runtime, there must be one device in the bound targets.
void ExcludeTargets(const std::set<TargetType>& targets) {
std::set<TargetType> res;
for (const auto& target : targets) {
const std::set<TargetType>& universe = ExpandValidTargets(target);
std::set_difference(bound_targets_.begin(),
bound_targets_.end(),
universe.begin(),
universe.end(),
std::inserter(res, res.begin()));
}
bound_targets_ = res;
} }
// Get all bound targets. // Get all bound targets.
const std::set<TargetType>& Targets() const { return bound_targets_; } const std::set<TargetType>& Targets() const { return bound_targets_; }
......
...@@ -34,6 +34,10 @@ class PassRegistry { ...@@ -34,6 +34,10 @@ class PassRegistry {
pass_->BindTargets(targets); pass_->BindTargets(targets);
return *this; return *this;
} }
PassRegistry& ExcludeTargets(const std::set<TargetType>& targets) {
pass_->ExcludeTargets(targets);
return *this;
}
PassRegistry& BindKernel(const std::string& name, PassRegistry& BindKernel(const std::string& name,
const lite_api::Place& place) { const lite_api::Place& place) {
pass_->BindKernel(name, place); pass_->BindKernel(name, place);
......
...@@ -23,52 +23,17 @@ namespace lite { ...@@ -23,52 +23,17 @@ namespace lite {
using lite_api::Place; using lite_api::Place;
namespace {
template <typename T>
class Types final {
public:
explicit Types(const std::set<T>& types) : types_(types) {}
~Types() = default;
std::set<T> ValidSet(const T& element) const;
private:
const std::set<T> types_;
};
template <typename T>
std::set<T> Types<T>::ValidSet(const T& element) const {
if (element == T::kAny) {
return types_;
} else if (element == T::kUnk) {
LOG(FATAL) << "The type of the kernel's place is unknown.";
}
return std::set<T>({element});
}
void ExpandPlaces(std::set<Place>* places, const Place& place) { void ExpandPlaces(std::set<Place>* places, const Place& place) {
static const Types<TargetType> target_set({TARGET(kHost), for (const auto& target : lite_api::ExpandValidTargets(place.target)) {
TARGET(kX86), for (const auto& precision :
TARGET(kCUDA), lite_api::ExpandValidPrecisions(place.precision)) {
TARGET(kARM), for (const auto& layout : lite_api::ExpandValidLayouts(place.layout)) {
TARGET(kOpenCL),
TARGET(kNPU),
TARGET(kFPGA)});
static const Types<PrecisionType> precision_set(
{PRECISION(kFloat), PRECISION(kInt8), PRECISION(kFP16), PRECISION(kAny)});
static const Types<DataLayoutType> layout_set(
{DATALAYOUT(kNCHW), DATALAYOUT(kAny), DATALAYOUT(kNHWC)});
for (const auto& target : target_set.ValidSet(place.target)) {
for (const auto& precision : precision_set.ValidSet(place.precision)) {
for (const auto& layout : layout_set.ValidSet(place.layout)) {
places->insert(Place(target, precision, layout)); places->insert(Place(target, precision, layout));
} }
} }
} }
} }
} // anonymous namespace
bool KernelRegistered(const std::string name, const Place& place) { bool KernelRegistered(const std::string name, const Place& place) {
std::set<Place> places; std::set<Place> places;
ExpandPlaces(&places, place); ExpandPlaces(&places, place);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册