未验证 提交 8ca10db8 编写于 作者: 石晓伟 提交者: GitHub

make passes related to the device type, test=develop (#2012)

* make passes related to the device type, test=develop

* improve tips, test=develop
上级 13bbd2b8
......@@ -17,6 +17,7 @@
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cudnn.h>
#include <limits>
#include <string>
namespace paddle {
......@@ -39,8 +40,8 @@ __device__ __forceinline__ half from_float<half>(float x) {
// Saturating float -> int8 conversion for CUDA device code: clamps x to the
// representable int8_t range, then converts with round-to-nearest-even.
template <>
__device__ __forceinline__ int8_t from_float<int8_t>(float x) {
  // Use int8_t limits explicitly: plain `char` has implementation-defined
  // signedness, so numeric_limits<char>::min() could be 0 on some platforms
  // and break the lower clamp.
  x = fmaxf(x, std::numeric_limits<int8_t>::min());
  x = fminf(x, std::numeric_limits<int8_t>::max());
  return __float2int_rn(x);
}
......
......@@ -42,4 +42,5 @@ class ArgumentTypeDisplayPass : public DebugPass {
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(argument_type_display_pass,
                  paddle::lite::mir::ArgumentTypeDisplayPass)
    .SetTargets({TARGET(kAny)});
......@@ -34,4 +34,4 @@ bool RegisterDemoPass() {
} // namespace lite
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(demo, paddle::lite::mir::DemoPass).SetTargets({TARGET(kAny)});
......@@ -69,4 +69,5 @@ class IdentityScaleEliminatePass : public ProgramPass {
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(identity_scale_eliminate_pass,
                  paddle::lite::mir::IdentityScaleEliminatePass)
    .SetTargets({TARGET(kAny)});
......@@ -38,4 +38,5 @@ void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(lite_conv_activation_fuse_pass,
                  paddle::lite::mir::ConvActivationFusePass)
    .SetTargets({TARGET(kAny)});
......@@ -34,4 +34,5 @@ void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace lite
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass)
    .SetTargets({TARGET(kAny)});
......@@ -38,4 +38,5 @@ void ConvElementwiseFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(lite_conv_elementwise_fuse_pass,
                  paddle::lite::mir::ConvElementwiseFusePass)
    .SetTargets({TARGET(kAny)});
......@@ -33,4 +33,5 @@ void ElementwiseAddActivationFusePass::Apply(
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass,
                  paddle::lite::mir::ElementwiseAddActivationFusePass)
    .SetTargets({TARGET(kAny)});
......@@ -31,4 +31,5 @@ void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace lite
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass)
    .SetTargets({TARGET(kAny)});
......@@ -35,4 +35,5 @@ void InterpolateFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(lite_interpolate_fuse_pass,
                  paddle::lite::mir::InterpolateFusePass)
    .SetTargets({TARGET(kAny)});
......@@ -15,6 +15,7 @@
#include "lite/core/mir/fusion/quant_dequant_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/core/mir/fusion/quant_dequant_op_fuser.h"
#include "lite/core/mir/pass_registry.h"
......@@ -42,4 +43,5 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(lite_quant_dequant_fuse_pass,
                  paddle::lite::mir::QuantDequantFusePass)
    .SetTargets({TARGET(kAny)});
......@@ -35,4 +35,5 @@ void ShuffleChannelFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(lite_shuffle_channel_fuse_pass,
                  paddle::lite::mir::ShuffleChannelFusePass)
    .SetTargets({TARGET(kAny)});
......@@ -36,4 +36,5 @@ void TransposeSoftmaxTransposeFusePass::Apply(
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass,
                  paddle::lite::mir::TransposeSoftmaxTransposeFusePass)
    .SetTargets({TARGET(kAny)});
......@@ -38,5 +38,5 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace lite
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(generate_program_pass, paddle::lite::mir::GenerateProgramPass)
    .SetTargets({TARGET(kAny)});
......@@ -98,4 +98,5 @@ std::string Visualize(mir::SSAGraph* graph) {
} // namespace lite
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
// NOTE: the registered name "graph_visualze" is misspelled but must be kept
// as-is — pass lookup is by this exact string elsewhere in the project.
REGISTER_MIR_PASS(graph_visualze, paddle::lite::mir::GraphVisualizePass)
    .SetTargets({TARGET(kAny)});
......@@ -71,4 +71,5 @@ class IoCopyKernelPickPass : public StmtPass {
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(io_copy_kernel_pick_pass,
                  paddle::lite::mir::IoCopyKernelPickPass)
    .SetTargets({TARGET(kAny)});
......@@ -14,7 +14,9 @@
#pragma once
#include <memory>
#include <set>
#include <string>
#include "lite/core/mir/node.h"
#include "lite/core/mir/ssa_graph.h"
......@@ -44,6 +46,13 @@ class Pass {
void set_doc(const std::string& doc) { doc_ = doc; }
const std::string& doc() const { return doc_; }
// Replaces the set of targets this pass declares support for.
void set_targets(const std::set<TargetType>& targets) { targets_ = targets; }
// The declared target set; TARGET(kAny) acts as a wildcard (see below).
const std::set<TargetType>& targets() const { return targets_; }
// True when `target` is in targets_, or when targets_ contains TARGET(kAny),
// which makes the pass applicable to every target.
bool is_supported_target(TargetType target) const {
if (targets_.find(TARGET(kAny)) != targets_.end()) return true;
return (targets_.find(target) != targets_.end());
}
Kind kind() const { return kind_; }
bool is_debug_pass() const { return kind_ == Kind::kDebug; }
bool is_program_pass() const { return kind_ == Kind::kProgramWise; }
......@@ -55,6 +64,7 @@ class Pass {
const Kind kind_;
std::string name_;
std::string doc_;
std::set<TargetType> targets_;
};
// Different kinds.
......
......@@ -14,8 +14,10 @@
#pragma once
#include <set>
#include <string>
#include "lite/api/paddle_lite_factory_helper.h"
#include "lite/api/paddle_place.h"
#include "lite/core/mir/pass_manager.h"
namespace paddle {
......@@ -24,12 +26,19 @@ namespace mir {
class PassRegistry {
public:
PassRegistry(const std::string& name, mir::Pass* pass) {
VLOG(2) << "Registry add MIR pass " << name;
PassManager::Global().AddNewPass(name, pass);
PassRegistry(const std::string& name, mir::Pass* pass)
: name_(name), pass_(pass) {
PassManager::Global().AddNewPass(name_, pass_);
}
PassRegistry& SetTargets(const std::set<TargetType>& targets) {
pass_->set_targets(targets);
return *this;
}
bool Touch() const { return true; }
private:
std::string name_;
mir::Pass* pass_;
};
} // namespace mir
......@@ -41,4 +50,6 @@ class PassRegistry {
new class__); \
bool mir_pass_registry##name__##_fake() { \
return mir_pass_registry##name__.Touch(); \
}
} \
static paddle::lite::mir::PassRegistry mir_pass_registry_func_##name__ \
__attribute__((unused)) = mir_pass_registry##name__
......@@ -38,4 +38,5 @@ class RuntimeContextAssignPass : public StmtPass {
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(runtime_context_assign_pass,
                  paddle::lite::mir::RuntimeContextAssignPass)
    .SetTargets({TARGET(kAny)});
......@@ -132,4 +132,5 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(static_kernel_pick_pass,
                  paddle::lite::mir::StaticKernelPickPass)
    .SetTargets({TARGET(kAny)});
......@@ -214,4 +214,5 @@ std::unique_ptr<RuntimeProgram> GenerateNPUProgramPass::GenProgram() {
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(generate_npu_program_pass,
                  paddle::lite::mir::subgraph::GenerateNPUProgramPass)
    .SetTargets({TARGET(kAny)});
......@@ -310,4 +310,5 @@ int SubgraphProgramPass::FuseSubgraph(
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(subgraph_program_pass,
                  paddle::lite::mir::subgraph::SubgraphProgramPass)
    .SetTargets({TARGET(kAny)});
......@@ -173,4 +173,5 @@ void TypeLayoutTransformPass::SetValidPlaces(
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(type_layout_cast_pass,
                  paddle::lite::mir::TypeLayoutTransformPass)
    .SetTargets({TARGET(kAny)});
......@@ -179,4 +179,5 @@ void PrecisionCastPass::SetValidPlaces(const std::vector<Place>& valid_places) {
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(type_precision_cast_pass,
                  paddle::lite::mir::PrecisionCastPass)
    .SetTargets({TARGET(kAny)});
......@@ -179,4 +179,5 @@ void TypeTargetTransformPass::SetValidPlaces(
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(type_target_cast_pass,
                  paddle::lite::mir::TypeTargetTransformPass)
    .SetTargets({TARGET(kAny)});
......@@ -31,4 +31,5 @@ void VariablePlaceInferencePass::Apply(const std::unique_ptr<SSAGraph> &graph) {
} // namespace paddle
// Register the pass; kAny marks it runnable for every target.
REGISTER_MIR_PASS(variable_place_inference_pass,
                  paddle::lite::mir::VariablePlaceInferencePass)
    .SetTargets({TARGET(kAny)});
......@@ -153,9 +153,6 @@ class KernelRegistry final {
const std::string &name,
typename KernelRegistryForTarget<Target, Precision, Layout>::creator_t
&&creator) {
VLOG(3) << "register for " << TargetToStr(Target) << ":"
<< PrecisionToStr(Precision) << "//"
<< GetKernelOffset<Target, Precision, Layout>();
using kernel_registor_t =
KernelRegistryForTarget<Target, Precision, Layout>;
auto &varient = registries_[GetKernelOffset<Target, Precision, Layout>()];
......@@ -219,9 +216,6 @@ class KernelRegistor : public lite::Registor<KernelType> {
public:
KernelRegistor(const std::string &op_type, const std::string &alias)
: Registor<KernelType>([=] {
VLOG(3) << "Register kernel " << op_type << " for "
<< TargetToStr(target) << " " << PrecisionToStr(precision)
<< " " << DataLayoutToStr(layout) << " alias " << alias;
KernelRegistry::Global().Register<target, precision, layout>(
op_type, [=]() -> std::unique_ptr<KernelType> {
std::unique_ptr<KernelType> x(new KernelType);
......
......@@ -183,11 +183,22 @@ class Optimizer {
// Specify the passes and run them.
void RunPasses(const std::vector<std::string>& passes) {
for (auto& x : passes) {
LOG(INFO) << "== Running pass " << x;
auto* pass = mir::PassManager::Global().LookUp(x);
LOG(INFO) << "== Running pass: " << x;
mir::Pass* pass = mir::PassManager::Global().LookUp(x);
CHECK(pass) << "Can not find pass: " << x;
pass->Apply(graph_);
LOG(INFO) << "== Running pass Done." << x;
bool supported = false;
for (const auto& place : valid_places_) {
if (pass->is_supported_target(place.target)) {
supported = true;
}
}
if (!supported) {
LOG(WARNING) << "Skip " << x
<< " pass because the target does not match.";
} else {
pass->Apply(graph_);
LOG(INFO) << "== Finished running: " << x;
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册