未验证 提交 92179da1 编写于 作者: 石晓伟 提交者: GitHub

refactor: BindTargets and ExcludeTargets, test=develop (#2321)

上级 d045f646
......@@ -122,5 +122,37 @@ const std::string& DataLayoutRepr(DataLayoutType layout) {
return datalayout2string[x];
}
std::set<TargetType> ExpandValidTargets(TargetType target) {
static const std::set<TargetType> valid_set({TARGET(kX86),
TARGET(kCUDA),
TARGET(kARM),
TARGET(kOpenCL),
TARGET(kNPU),
TARGET(kXPU),
TARGET(kFPGA)});
if (target == TARGET(kAny)) {
return valid_set;
}
return std::set<TargetType>({target});
}
std::set<PrecisionType> ExpandValidPrecisions(PrecisionType precision) {
static const std::set<PrecisionType> valid_set(
{PRECISION(kFloat), PRECISION(kInt8), PRECISION(kFP16), PRECISION(kAny)});
if (precision == PRECISION(kAny)) {
return valid_set;
}
return std::set<PrecisionType>({precision});
}
std::set<DataLayoutType> ExpandValidLayouts(DataLayoutType layout) {
static const std::set<DataLayoutType> valid_set(
{DATALAYOUT(kNCHW), DATALAYOUT(kAny), DATALAYOUT(kNHWC)});
if (layout == DATALAYOUT(kAny)) {
return valid_set;
}
return std::set<DataLayoutType>({layout});
}
} // namespace lite_api
} // namespace paddle
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <set>
#include <string>
// Generic helper definitions for shared library support
......@@ -125,6 +126,17 @@ const std::string& PrecisionRepr(PrecisionType precision);
const std::string& DataLayoutRepr(DataLayoutType layout);
// Get a set of all the elements represented by the target.
std::set<TargetType> ExpandValidTargets(TargetType target = TARGET(kAny));
// Get a set of all the elements represented by the precision.
std::set<PrecisionType> ExpandValidPrecisions(
PrecisionType precision = PRECISION(kAny));
// Get a set of all the elements represented by the layout.
std::set<DataLayoutType> ExpandValidLayouts(
DataLayoutType layout = DATALAYOUT(kAny));
/*
* Place specifies the execution context of a Kernel or input/output for a
* kernel. It is used to make the analysis of the MIR more clear and accurate.
......
......@@ -44,4 +44,5 @@ void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle
REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass)
.BindTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)})
.ExcludeTargets({TARGET(kX86)});
......@@ -49,10 +49,35 @@ class Pass {
// Some passes only apply to qualified targets, which need to be explicitly
// declared.
// Bind the target. At runtime, there must be one device in the bound targets.
// Bind targets. At runtime, there must be one device in the bound targets.
void BindTargets(const std::set<TargetType>& targets) {
bound_targets_ = targets;
std::set<TargetType> res;
for (const auto& target : targets) {
const std::set<TargetType>& universe = ExpandValidTargets(target);
std::set_union(bound_targets_.begin(),
bound_targets_.end(),
universe.begin(),
universe.end(),
std::inserter(res, res.begin()));
}
bound_targets_ = res;
}
// Exclude targets. At runtime, there must be one device in the bound targets.
void ExcludeTargets(const std::set<TargetType>& targets) {
std::set<TargetType> res;
for (const auto& target : targets) {
const std::set<TargetType>& universe = ExpandValidTargets(target);
std::set_difference(bound_targets_.begin(),
bound_targets_.end(),
universe.begin(),
universe.end(),
std::inserter(res, res.begin()));
}
bound_targets_ = res;
}
// Get all bound targets.
const std::set<TargetType>& Targets() const { return bound_targets_; }
......
......@@ -34,6 +34,10 @@ class PassRegistry {
pass_->BindTargets(targets);
return *this;
}
PassRegistry& ExcludeTargets(const std::set<TargetType>& targets) {
pass_->ExcludeTargets(targets);
return *this;
}
PassRegistry& BindKernel(const std::string& name,
const lite_api::Place& place) {
pass_->BindKernel(name, place);
......
......@@ -23,53 +23,17 @@ namespace lite {
using lite_api::Place;
namespace {
template <typename T>
class Types final {
public:
explicit Types(const std::set<T>& types) : types_(types) {}
~Types() = default;
std::set<T> ValidSet(const T& element) const;
private:
const std::set<T> types_;
};
template <typename T>
std::set<T> Types<T>::ValidSet(const T& element) const {
if (element == T::kAny) {
return types_;
} else if (element == T::kUnk) {
LOG(FATAL) << "The type of the kernel's place is unknown.";
}
return std::set<T>({element});
}
void ExpandPlaces(std::set<Place>* places, const Place& place) {
static const Types<TargetType> target_set({TARGET(kHost),
TARGET(kX86),
TARGET(kCUDA),
TARGET(kARM),
TARGET(kOpenCL),
TARGET(kNPU),
TARGET(kXPU),
TARGET(kFPGA)});
static const Types<PrecisionType> precision_set(
{PRECISION(kFloat), PRECISION(kInt8), PRECISION(kFP16), PRECISION(kAny)});
static const Types<DataLayoutType> layout_set(
{DATALAYOUT(kNCHW), DATALAYOUT(kAny), DATALAYOUT(kNHWC)});
for (const auto& target : target_set.ValidSet(place.target)) {
for (const auto& precision : precision_set.ValidSet(place.precision)) {
for (const auto& layout : layout_set.ValidSet(place.layout)) {
for (const auto& target : lite_api::ExpandValidTargets(place.target)) {
for (const auto& precision :
lite_api::ExpandValidPrecisions(place.precision)) {
for (const auto& layout : lite_api::ExpandValidLayouts(place.layout)) {
places->insert(Place(target, precision, layout));
}
}
}
}
} // anonymous namespace
bool KernelRegistered(const std::string name, const Place& place) {
std::set<Place> places;
ExpandPlaces(&places, place);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册