提交 dbc8f893 编写于 作者: 石晓伟 提交者: GitHub

modify the device binding logic of the pass, test=develop (#2060)

上级 3682a9df
...@@ -39,4 +39,5 @@ void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -39,4 +39,5 @@ void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_conv_activation_fuse_pass, REGISTER_MIR_PASS(lite_conv_activation_fuse_pass,
paddle::lite::mir::ConvActivationFusePass) paddle::lite::mir::ConvActivationFusePass)
.BindTargets({TARGET(kAny)}); .BindTargets({TARGET(kAny)})
.BindKernel("conv2d");
...@@ -35,4 +35,5 @@ void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -35,4 +35,5 @@ void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass) REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass)
.BindTargets({TARGET(kAny)}); .BindTargets({TARGET(kAny)})
.BindKernel("elementwise_add");
...@@ -34,4 +34,5 @@ void ElementwiseAddActivationFusePass::Apply( ...@@ -34,4 +34,5 @@ void ElementwiseAddActivationFusePass::Apply(
REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass, REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass,
paddle::lite::mir::ElementwiseAddActivationFusePass) paddle::lite::mir::ElementwiseAddActivationFusePass)
.BindTargets({TARGET(kAny)}); .BindTargets({TARGET(kAny)})
.BindKernel("fusion_elementwise_add_activation");
...@@ -32,4 +32,5 @@ void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -32,4 +32,5 @@ void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass) REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass)
.BindTargets({TARGET(kAny)}); .BindTargets({TARGET(kAny)})
.BindKernel("fc");
...@@ -36,4 +36,5 @@ void ShuffleChannelFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -36,4 +36,5 @@ void ShuffleChannelFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_shuffle_channel_fuse_pass, REGISTER_MIR_PASS(lite_shuffle_channel_fuse_pass,
paddle::lite::mir::ShuffleChannelFusePass) paddle::lite::mir::ShuffleChannelFusePass)
.BindTargets({TARGET(kAny)}); .BindTargets({TARGET(kAny)})
.BindKernel("shuffle_channel");
...@@ -72,4 +72,5 @@ class IoCopyKernelPickPass : public StmtPass { ...@@ -72,4 +72,5 @@ class IoCopyKernelPickPass : public StmtPass {
REGISTER_MIR_PASS(io_copy_kernel_pick_pass, REGISTER_MIR_PASS(io_copy_kernel_pick_pass,
paddle::lite::mir::IoCopyKernelPickPass) paddle::lite::mir::IoCopyKernelPickPass)
.BindTargets({TARGET(kAny)}); .BindTargets({TARGET(kAny)})
.BindKernel("io_copy");
...@@ -39,6 +39,11 @@ class PassRegistry { ...@@ -39,6 +39,11 @@ class PassRegistry {
pass_->BindKernel(name, place); pass_->BindKernel(name, place);
return *this; return *this;
} }
PassRegistry& BindKernel(const std::string& name) {
pass_->BindKernel(name,
Place(TARGET(kAny), PRECISION(kAny), DATALAYOUT(kAny)));
return *this;
}
bool Touch() const { return true; } bool Touch() const { return true; }
private: private:
......
...@@ -16,10 +16,72 @@ ...@@ -16,10 +16,72 @@
#include <set> #include <set>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "lite/core/op_registry.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
using lite_api::Place;
namespace {
template <typename T>
class Types final {
public:
explicit Types(const std::set<T>& types) : types_(types) {}
~Types() = default;
std::set<T> ValidSet(const T& element) const;
private:
const std::set<T> types_;
};
template <typename T>
std::set<T> Types<T>::ValidSet(const T& element) const {
if (element == T::kAny) {
return types_;
} else if (element == T::kUnk) {
LOG(FATAL) << "The type of the kernel's place is unknown.";
}
return std::set<T>({element});
}
bool ExpandPlaces(std::set<Place>* places, const Place& place) {
static const Types<TargetType> target_set({TARGET(kHost),
TARGET(kX86),
TARGET(kCUDA),
TARGET(kARM),
TARGET(kOpenCL),
TARGET(kNPU),
TARGET(kFPGA)});
static const Types<PrecisionType> precision_set(
{PRECISION(kFloat), PRECISION(kInt8), PRECISION(kFP16), PRECISION(kAny)});
static const Types<DataLayoutType> layout_set(
{DATALAYOUT(kNCHW), DATALAYOUT(kAny), DATALAYOUT(kNHWC)});
for (const auto& target : target_set.ValidSet(place.target)) {
for (const auto& precision : precision_set.ValidSet(place.precision)) {
for (const auto& layout : layout_set.ValidSet(place.layout)) {
places->insert(Place(target, precision, layout));
}
}
}
}
} // anonymous namespace
bool KernelRegistered(const std::string name, const Place& place) {
std::set<Place> places;
ExpandPlaces(&places, place);
for (const auto& p : places) {
if (!KernelRegistry::Global()
.Create(name, p.target, p.precision, p.layout)
.empty()) {
return true;
}
}
return false;
}
bool PassMatchesTarget(const mir::Pass& pass, TargetType target) { bool PassMatchesTarget(const mir::Pass& pass, TargetType target) {
const auto& targets = pass.Targets(); const auto& targets = pass.Targets();
if (targets.find(TARGET(kAny)) != targets.end()) return true; if (targets.find(TARGET(kAny)) != targets.end()) return true;
...@@ -30,10 +92,9 @@ bool PassMatchesKernels(const mir::Pass& pass) { ...@@ -30,10 +92,9 @@ bool PassMatchesKernels(const mir::Pass& pass) {
const auto& kernels = pass.GetBoundKernels(); const auto& kernels = pass.GetBoundKernels();
for (const auto& kernel : kernels) { for (const auto& kernel : kernels) {
for (const auto& place : kernel.second) { for (const auto& place : kernel.second) {
if (KernelRegistry::Global() if (!KernelRegistered(kernel.first, place)) {
.Create(kernel.first, place.target, place.precision, place.layout)
.empty())
return false; return false;
}
} }
} }
return true; return true;
......
...@@ -14,11 +14,15 @@ ...@@ -14,11 +14,15 @@
#pragma once #pragma once
#include <string>
#include "lite/core/mir/pass.h" #include "lite/core/mir/pass.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
// Query if the specified kernel has been registered.
bool KernelRegistered(const std::string name, const Place& place);
// Check if the pass hits the hardware target. // Check if the pass hits the hardware target.
bool PassMatchesTarget(const mir::Pass& pass, TargetType target); bool PassMatchesTarget(const mir::Pass& pass, TargetType target);
......
...@@ -215,4 +215,4 @@ std::unique_ptr<RuntimeProgram> GenerateNPUProgramPass::GenProgram() { ...@@ -215,4 +215,4 @@ std::unique_ptr<RuntimeProgram> GenerateNPUProgramPass::GenProgram() {
REGISTER_MIR_PASS(generate_npu_program_pass, REGISTER_MIR_PASS(generate_npu_program_pass,
paddle::lite::mir::subgraph::GenerateNPUProgramPass) paddle::lite::mir::subgraph::GenerateNPUProgramPass)
.BindTargets({TARGET(kAny)}); .BindTargets({TARGET(kNPU)});
...@@ -174,4 +174,6 @@ void TypeLayoutTransformPass::SetValidPlaces( ...@@ -174,4 +174,6 @@ void TypeLayoutTransformPass::SetValidPlaces(
REGISTER_MIR_PASS(type_layout_cast_pass, REGISTER_MIR_PASS(type_layout_cast_pass,
paddle::lite::mir::TypeLayoutTransformPass) paddle::lite::mir::TypeLayoutTransformPass)
.BindTargets({TARGET(kAny)}); .BindTargets({TARGET(kAny)})
.BindKernel("layout_once")
.BindKernel("layout");
...@@ -180,4 +180,6 @@ void PrecisionCastPass::SetValidPlaces(const std::vector<Place>& valid_places) { ...@@ -180,4 +180,6 @@ void PrecisionCastPass::SetValidPlaces(const std::vector<Place>& valid_places) {
REGISTER_MIR_PASS(type_precision_cast_pass, REGISTER_MIR_PASS(type_precision_cast_pass,
paddle::lite::mir::PrecisionCastPass) paddle::lite::mir::PrecisionCastPass)
.BindTargets({TARGET(kAny)}); .BindTargets({TARGET(kAny)})
.BindKernel("calib_once")
.BindKernel("calib");
...@@ -180,4 +180,6 @@ void TypeTargetTransformPass::SetValidPlaces( ...@@ -180,4 +180,6 @@ void TypeTargetTransformPass::SetValidPlaces(
REGISTER_MIR_PASS(type_target_cast_pass, REGISTER_MIR_PASS(type_target_cast_pass,
paddle::lite::mir::TypeTargetTransformPass) paddle::lite::mir::TypeTargetTransformPass)
.BindTargets({TARGET(kAny)}); .BindTargets({TARGET(kAny)})
.BindKernel("io_copy_once")
.BindKernel("io_copy");
...@@ -174,9 +174,13 @@ class KernelRegistry final { ...@@ -174,9 +174,13 @@ class KernelRegistry final {
std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type) { std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type) {
using kernel_registor_t = using kernel_registor_t =
KernelRegistryForTarget<Target, Precision, Layout>; KernelRegistryForTarget<Target, Precision, Layout>;
return registries_[GetKernelOffset<Target, Precision, Layout>()] std::list<std::unique_ptr<KernelBase>> kernel_list;
.template get<kernel_registor_t *>() if (registries_[GetKernelOffset<Target, Precision, Layout>()].valid()) {
->Creates(op_type); kernel_list = registries_[GetKernelOffset<Target, Precision, Layout>()]
.template get<kernel_registor_t *>()
->Creates(op_type);
}
return kernel_list;
} }
std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type, std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type,
......
...@@ -193,9 +193,9 @@ class Optimizer { ...@@ -193,9 +193,9 @@ class Optimizer {
matched = true; matched = true;
} }
} }
matched = matched || PassMatchesKernels(*pass); matched = matched && PassMatchesKernels(*pass);
if (!matched) { if (!matched) {
LOG(INFO) << "Skip " << x << " pass because the target does not match."; LOG(INFO) << " - Skip " << x << " because the target does not match.";
} else { } else {
pass->Apply(graph_); pass->Apply(graph_);
LOG(INFO) << "== Finished running: " << x; LOG(INFO) << "== Finished running: " << x;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册