Commit dbc8f893 authored by 石晓伟, committed by GitHub

modify the device binding logic of the pass, test=develop (#2060)

Parent 3682a9df
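In short, the commit lets each MIR pass declare the kernels it depends on in addition to its hardware targets, so the optimizer can skip a pass whose kernels are not compiled into the current build. A minimal sketch of the resulting registration pattern, using a hypothetical pass and kernel name (the real registrations are in the hunks below):

// Hypothetical example; only the chained BindTargets/BindKernel calls
// reflect the API touched by this commit.
REGISTER_MIR_PASS(lite_example_fuse_pass, paddle::lite::mir::ExampleFusePass)
    .BindTargets({TARGET(kAny)})   // runs for any hardware target
    .BindKernel("example_op");     // but only if this kernel is registered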
......@@ -39,4 +39,5 @@ void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_conv_activation_fuse_pass,
paddle::lite::mir::ConvActivationFusePass)
.BindTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)})
.BindKernel("conv2d");
......@@ -35,4 +35,5 @@ void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle
REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass)
.BindTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)})
.BindKernel("elementwise_add");
......@@ -34,4 +34,5 @@ void ElementwiseAddActivationFusePass::Apply(
REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass,
paddle::lite::mir::ElementwiseAddActivationFusePass)
.BindTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)})
.BindKernel("fusion_elementwise_add_activation");
......@@ -32,4 +32,5 @@ void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle
REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass)
.BindTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)})
.BindKernel("fc");
......@@ -36,4 +36,5 @@ void ShuffleChannelFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_shuffle_channel_fuse_pass,
paddle::lite::mir::ShuffleChannelFusePass)
.BindTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)})
.BindKernel("shuffle_channel");
......@@ -72,4 +72,5 @@ class IoCopyKernelPickPass : public StmtPass {
REGISTER_MIR_PASS(io_copy_kernel_pick_pass,
paddle::lite::mir::IoCopyKernelPickPass)
.BindTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)})
.BindKernel("io_copy");
......@@ -39,6 +39,11 @@ class PassRegistry {
    pass_->BindKernel(name, place);
    return *this;
  }
  PassRegistry& BindKernel(const std::string& name) {
    pass_->BindKernel(name,
                      Place(TARGET(kAny), PRECISION(kAny), DATALAYOUT(kAny)));
    return *this;
  }
  bool Touch() const { return true; }

 private:
......
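With the single-argument overload above, binding a kernel by name alone is shorthand for binding it with a fully wildcard place. Roughly, on a PassRegistry instance (here a hypothetical registry, with the kernel name "fc" taken from the fc_fuse_pass hunk below):

// The two forms below should be equivalent given the overload added above;
// the one-argument form simply defaults to the any/any/any place.
registry.BindKernel("fc");
registry.BindKernel("fc", Place(TARGET(kAny), PRECISION(kAny), DATALAYOUT(kAny)));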
......@@ -16,10 +16,72 @@
#include <set>
#include <string>
#include <unordered_map>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
using lite_api::Place;
namespace {
template <typename T>
class Types final {
 public:
  explicit Types(const std::set<T>& types) : types_(types) {}
  ~Types() = default;
  std::set<T> ValidSet(const T& element) const;

 private:
  const std::set<T> types_;
};

template <typename T>
std::set<T> Types<T>::ValidSet(const T& element) const {
  if (element == T::kAny) {
    return types_;
  } else if (element == T::kUnk) {
    LOG(FATAL) << "The type of the kernel's place is unknown.";
  }
  return std::set<T>({element});
}
void ExpandPlaces(std::set<Place>* places, const Place& place) {
  static const Types<TargetType> target_set({TARGET(kHost),
                                             TARGET(kX86),
                                             TARGET(kCUDA),
                                             TARGET(kARM),
                                             TARGET(kOpenCL),
                                             TARGET(kNPU),
                                             TARGET(kFPGA)});
  static const Types<PrecisionType> precision_set(
      {PRECISION(kFloat), PRECISION(kInt8), PRECISION(kFP16), PRECISION(kAny)});
  static const Types<DataLayoutType> layout_set(
      {DATALAYOUT(kNCHW), DATALAYOUT(kAny), DATALAYOUT(kNHWC)});
  for (const auto& target : target_set.ValidSet(place.target)) {
    for (const auto& precision : precision_set.ValidSet(place.precision)) {
      for (const auto& layout : layout_set.ValidSet(place.layout)) {
        places->insert(Place(target, precision, layout));
      }
    }
  }
}
} // anonymous namespace
bool KernelRegistered(const std::string name, const Place& place) {
  std::set<Place> places;
  ExpandPlaces(&places, place);
  for (const auto& p : places) {
    if (!KernelRegistry::Global()
             .Create(name, p.target, p.precision, p.layout)
             .empty()) {
      return true;
    }
  }
  return false;
}
bool PassMatchesTarget(const mir::Pass& pass, TargetType target) {
  const auto& targets = pass.Targets();
  if (targets.find(TARGET(kAny)) != targets.end()) return true;
......@@ -30,10 +92,9 @@ bool PassMatchesKernels(const mir::Pass& pass) {
  const auto& kernels = pass.GetBoundKernels();
  for (const auto& kernel : kernels) {
    for (const auto& place : kernel.second) {
      if (KernelRegistry::Global()
              .Create(kernel.first, place.target, place.precision, place.layout)
              .empty())
      if (!KernelRegistered(kernel.first, place)) {
        return false;
      }
    }
  }
  return true;
......
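Taken together, ExpandPlaces and KernelRegistered mean that a kernel bound with "any" fields counts as registered if at least one concrete (target, precision, layout) combination has an implementation. A hedged usage sketch (the op name "fc" is only an example):

// Sketch: a wildcard place is expanded into the cross product of the concrete
// targets, precisions and layouts listed above, and the check succeeds as soon
// as one of those places has an "fc" kernel in the registry.
bool fc_available = paddle::lite::KernelRegistered(
    "fc", Place(TARGET(kAny), PRECISION(kAny), DATALAYOUT(kAny)));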
......@@ -14,11 +14,15 @@
#pragma once
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
// Query if the specified kernel has been registered.
bool KernelRegistered(const std::string name, const Place& place);
// Check if the pass hits the hardware target.
bool PassMatchesTarget(const mir::Pass& pass, TargetType target);
......
......@@ -215,4 +215,4 @@ std::unique_ptr<RuntimeProgram> GenerateNPUProgramPass::GenProgram() {
REGISTER_MIR_PASS(generate_npu_program_pass,
paddle::lite::mir::subgraph::GenerateNPUProgramPass)
.BindTargets({TARGET(kAny)});
.BindTargets({TARGET(kNPU)});
......@@ -174,4 +174,6 @@ void TypeLayoutTransformPass::SetValidPlaces(
REGISTER_MIR_PASS(type_layout_cast_pass,
paddle::lite::mir::TypeLayoutTransformPass)
.BindTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)})
.BindKernel("layout_once")
.BindKernel("layout");
......@@ -180,4 +180,6 @@ void PrecisionCastPass::SetValidPlaces(const std::vector<Place>& valid_places) {
REGISTER_MIR_PASS(type_precision_cast_pass,
paddle::lite::mir::PrecisionCastPass)
.BindTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)})
.BindKernel("calib_once")
.BindKernel("calib");
......@@ -180,4 +180,6 @@ void TypeTargetTransformPass::SetValidPlaces(
REGISTER_MIR_PASS(type_target_cast_pass,
paddle::lite::mir::TypeTargetTransformPass)
.BindTargets({TARGET(kAny)});
.BindTargets({TARGET(kAny)})
.BindKernel("io_copy_once")
.BindKernel("io_copy");
......@@ -174,9 +174,13 @@ class KernelRegistry final {
  std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type) {
    using kernel_registor_t =
        KernelRegistryForTarget<Target, Precision, Layout>;
    return registries_[GetKernelOffset<Target, Precision, Layout>()]
        .template get<kernel_registor_t *>()
        ->Creates(op_type);
    std::list<std::unique_ptr<KernelBase>> kernel_list;
    if (registries_[GetKernelOffset<Target, Precision, Layout>()].valid()) {
      kernel_list = registries_[GetKernelOffset<Target, Precision, Layout>()]
                        .template get<kernel_registor_t *>()
                        ->Creates(op_type);
    }
    return kernel_list;
  }

  std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type,
......
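The valid() guard matters because the new kernel check probes every compiled-in target: for a registry slot that was never initialized (for example a CUDA slot in a CPU-only build), Create now returns an empty list instead of fetching from an invalid entry. An illustrative query, assuming the non-templated Create overload ends up on the guarded path above (op name and place are only examples):

// Sketch: asking for a kernel on a target that is absent from the current
// build should simply yield an empty list rather than crash.
auto kernels = KernelRegistry::Global().Create(
    "conv2d", TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW));
if (kernels.empty()) {
  // no conv2d kernel registered for this place in this build
}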
......@@ -193,9 +193,9 @@ class Optimizer {
          matched = true;
        }
      }
      matched = matched || PassMatchesKernels(*pass);
      matched = matched && PassMatchesKernels(*pass);
      if (!matched) {
        LOG(INFO) << "Skip " << x << " pass because the target does not match.";
        LOG(INFO) << " - Skip " << x << " because the target does not match.";
      } else {
        pass->Apply(graph_);
        LOG(INFO) << "== Finished running: " << x;
......
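Switching || to && tightens the skip condition: a pass is applied only when it matches one of the valid targets and all of its bound kernels are registered, whereas previously a kernel match alone could let a pass run on a non-matching target. In condensed form (variable name is illustrative, not from the source):

// Before: run if target_matched || kernels_registered
// After:  run only if target_matched && kernels_registered
bool should_run = target_matched && PassMatchesKernels(*pass);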