未验证 提交 3a387df6 编写于 作者: W Wilber 提交者: GitHub

[Inference] inference add cinn interface (#48741)

上级 379216ae
...@@ -484,7 +484,8 @@ void AnalyseClusterVariables( ...@@ -484,7 +484,8 @@ void AnalyseClusterVariables(
const std::unordered_set<std::string>& deny_var_set, const std::unordered_set<std::string>& deny_var_set,
GraphNodeSet* cluster_inputs, GraphNodeSet* cluster_inputs,
GraphNodeSet* cluster_outputs, GraphNodeSet* cluster_outputs,
GraphNodeSet* cluster_internals) { GraphNodeSet* cluster_internals,
bool is_inference_stage) {
// collecting all input and output of op // collecting all input and output of op
for (auto* op_node : cluster) { for (auto* op_node : cluster) {
const auto& op_name = op_node->Name(); const auto& op_name = op_node->Name();
...@@ -523,6 +524,18 @@ void AnalyseClusterVariables( ...@@ -523,6 +524,18 @@ void AnalyseClusterVariables(
for (auto* var_node : *cluster_internals) { for (auto* var_node : *cluster_internals) {
cluster_outputs->erase(var_node); cluster_outputs->erase(var_node);
} }
if (is_inference_stage) {
// If part of the output of the Op is not used by other operators, change it
// to internal. such as transpose2 op's XShape out.
auto outs = *cluster_outputs;
for (auto* node : outs) {
if (node->outputs.empty()) {
cluster_outputs->erase(node);
cluster_internals->insert(node);
}
}
}
} }
void AddLinkToCinnOp(const GraphNodeSet& cluster_inputs, void AddLinkToCinnOp(const GraphNodeSet& cluster_inputs,
...@@ -611,7 +624,7 @@ void ReplaceSubGraphWithCinnOpNode( ...@@ -611,7 +624,7 @@ void ReplaceSubGraphWithCinnOpNode(
// Here we using SubgraphDetector to detecte the subgraph that // Here we using SubgraphDetector to detecte the subgraph that
// all of op node supported by CINN. We using OpMapperRegistry // all of op node supported by CINN. We using OpMapperRegistry
// to check whether the op node supported by CINN. // to check whether the op node supported by CINN.
void SearchAllSubgraphs(Graph* graph) { void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim); auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim);
auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim); auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim);
OpTransInfo trans_info; OpTransInfo trans_info;
...@@ -671,7 +684,8 @@ void SearchAllSubgraphs(Graph* graph) { ...@@ -671,7 +684,8 @@ void SearchAllSubgraphs(Graph* graph) {
deny_var_set, deny_var_set,
&cluster_inputs, &cluster_inputs,
&cluster_outputs, &cluster_outputs,
&cluster_internals); &cluster_internals,
is_inference_stage);
VLOG(4) << "Cluster Ops: " << cluster_debug_info(cluster_set); VLOG(4) << "Cluster Ops: " << cluster_debug_info(cluster_set);
VLOG(4) << "Cluster input vars: " << cluster_debug_info(cluster_inputs); VLOG(4) << "Cluster input vars: " << cluster_debug_info(cluster_inputs);
...@@ -698,7 +712,13 @@ void SearchAllSubgraphs(Graph* graph) { ...@@ -698,7 +712,13 @@ void SearchAllSubgraphs(Graph* graph) {
} }
} // namespace } // namespace
void BuildCinnPass::ApplyImpl(Graph* graph) const { SearchAllSubgraphs(graph); } void BuildCinnPass::ApplyImpl(Graph* graph) const {
bool is_inference_stage{false};
if (Has("is_inference_stage")) {
is_inference_stage = Get<bool>("is_inference_stage");
}
SearchAllSubgraphs(graph, is_inference_stage);
}
} // namespace paddle2cinn } // namespace paddle2cinn
} // namespace framework } // namespace framework
......
...@@ -368,6 +368,9 @@ struct Argument { ...@@ -368,6 +368,9 @@ struct Argument {
DECL_ARGUMENT_FIELD(enable_gpu_half, EnableGPUHalf, bool); DECL_ARGUMENT_FIELD(enable_gpu_half, EnableGPUHalf, bool);
DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int); DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);
// cinn compiler related
DECL_ARGUMENT_FIELD(use_cinn_compiler, UseCinnCompiler, bool);
private: private:
std::unordered_set<std::string> valid_fields_; std::unordered_set<std::string> valid_fields_;
}; };
......
...@@ -235,6 +235,8 @@ void IRPassManager::CreatePasses(Argument *argument, ...@@ -235,6 +235,8 @@ void IRPassManager::CreatePasses(Argument *argument,
new framework::ProgramDesc *(&argument->main_program())); new framework::ProgramDesc *(&argument->main_program()));
} else if (pass_name == "memory_optimize_pass") { } else if (pass_name == "memory_optimize_pass") {
pass->Set("root_predictor_id", new int(argument->root_predictor_id())); pass->Set("root_predictor_id", new int(argument->root_predictor_id()));
} else if (pass_name == "build_cinn_pass") {
pass->Set("is_inference_stage", new bool(argument->use_cinn_compiler()));
} }
if (pass_name == "lite_subgraph_pass") { if (pass_name == "lite_subgraph_pass") {
bool lite_enable_int8 = bool lite_enable_int8 =
......
...@@ -477,6 +477,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { ...@@ -477,6 +477,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
// profile related. // profile related.
CP_MEMBER(with_profile_); CP_MEMBER(with_profile_);
// cinn compiler related.
CP_MEMBER(use_cinn_compiler_);
// glog related. // glog related.
CP_MEMBER(with_glog_info_); CP_MEMBER(with_glog_info_);
...@@ -542,7 +545,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { ...@@ -542,7 +545,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
#undef CP_MEMBER #undef CP_MEMBER
Update(); Update();
if (use_tensorrt_) { if (use_tensorrt_ || use_cinn_compiler_) {
// Update() will reset all the passes, when some tensorRT pass is deleted in // Update() will reset all the passes, when some tensorRT pass is deleted in
// other.pass_builder(), it will set again, so we just remove the // other.pass_builder(), it will set again, so we just remove the
// deleted_pass. // deleted_pass.
...@@ -872,6 +875,14 @@ void AnalysisConfig::Update() { ...@@ -872,6 +875,14 @@ void AnalysisConfig::Update() {
} }
} }
// TODO(wilber): An ugly method to update pass, need to be fixed.
if (use_cinn_compiler_) {
pass_builder()->ClearPasses();
for (const auto &pass : kCINNCompilerPasses) {
pass_builder()->AppendPass(pass);
}
}
if (use_dlnne_) { if (use_dlnne_) {
pass_builder()->ClearPasses(); pass_builder()->ClearPasses();
for (const auto &pass : kDlnneSubgraphPasses) { for (const auto &pass : kDlnneSubgraphPasses) {
...@@ -1316,6 +1327,9 @@ std::string AnalysisConfig::Summary() { ...@@ -1316,6 +1327,9 @@ std::string AnalysisConfig::Summary() {
os.InsertRow({"use_lite", use_lite_ ? "true" : "false"}); os.InsertRow({"use_lite", use_lite_ ? "true" : "false"});
} }
// cinn compiler
os.InsertRow({"use_cinn_compiler", use_cinn_compiler_ ? "true" : "false"});
// ir info // ir info
os.InsertRow({"ir_optim", enable_ir_optim_ ? "true" : "false"}); os.InsertRow({"ir_optim", enable_ir_optim_ ? "true" : "false"});
os.InsertRow({"ir_debug", ir_debug_ ? "true" : "false"}); os.InsertRow({"ir_debug", ir_debug_ ? "true" : "false"});
...@@ -1429,4 +1443,19 @@ void AnalysisConfig::Exp_DisableMixedInferOps( ...@@ -1429,4 +1443,19 @@ void AnalysisConfig::Exp_DisableMixedInferOps(
mixed_black_list_ = black_list; mixed_black_list_ = black_list;
} }
void AnalysisConfig::Exp_EnableCINNCompiler() {
#ifdef PADDLE_WITH_CINN
use_cinn_compiler_ = true;
Update();
#else
PADDLE_THROW(platform::errors::Unavailable(
"You tried to use CINN compiler, but Paddle was not compiled "
"with CINN."));
#endif
}
bool AnalysisConfig::cinn_compiler_enabled() const {
return use_cinn_compiler_;
}
} // namespace paddle } // namespace paddle
...@@ -1217,6 +1217,10 @@ void AnalysisPredictor::PrepareArgument() { ...@@ -1217,6 +1217,10 @@ void AnalysisPredictor::PrepareArgument() {
argument_.SetMKLDNNEnabledOpTypes(config_.mkldnn_enabled_op_types_); argument_.SetMKLDNNEnabledOpTypes(config_.mkldnn_enabled_op_types_);
} }
if (config_.use_cinn_compiler_) {
argument_.SetUseCinnCompiler(config_.use_cinn_compiler_);
}
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (config_.mkldnn_quantizer_enabled()) { if (config_.mkldnn_quantizer_enabled()) {
LOG(INFO) << "Quantization is enabled"; LOG(INFO) << "Quantization is enabled";
...@@ -1239,21 +1243,25 @@ void AnalysisPredictor::PrepareArgument() { ...@@ -1239,21 +1243,25 @@ void AnalysisPredictor::PrepareArgument() {
#endif #endif
auto *pass_builder = config_.pass_builder(); auto *pass_builder = config_.pass_builder();
// TODO(inference): Need to reconstruct the pass_builder, pass should be
// processed in a single
if (model_precision_ != phi::DataType::FLOAT32) { if (model_precision_ != phi::DataType::FLOAT32) {
LOG(INFO) << "Model is mixed precision type with " << model_precision_ LOG(INFO) << "Model is mixed precision type with " << model_precision_
<< ", we will use a new PassStrategy. Note that only the GPU " << ", we will use a new PassStrategy. Note that only the GPU "
"backend is supported for now."; "backend is supported for now.";
pass_builder->ClearPasses(); if (!config_.use_cinn_compiler_) {
const auto &deleted_passes = pass_builder->GetAllDeletedPasses(); pass_builder->ClearPasses();
if (config_.tensorrt_engine_enabled()) { const auto &deleted_passes = pass_builder->GetAllDeletedPasses();
for (const auto &pass : kTrtLowerPrecisionPasses) { if (config_.tensorrt_engine_enabled()) {
if (deleted_passes.count(pass)) continue; for (const auto &pass : kTrtLowerPrecisionPasses) {
pass_builder->AppendPass(pass); if (deleted_passes.count(pass)) continue;
} pass_builder->AppendPass(pass);
} else if (config_.use_gpu()) { }
for (const auto &pass : kGpuLowerPrecisionPasses) { } else if (config_.use_gpu()) {
if (deleted_passes.count(pass)) continue; for (const auto &pass : kGpuLowerPrecisionPasses) {
pass_builder->AppendPass(pass); if (deleted_passes.count(pass)) continue;
pass_builder->AppendPass(pass);
}
} }
} }
} }
......
...@@ -1016,6 +1016,19 @@ struct PD_INFER_DECL AnalysisConfig { ...@@ -1016,6 +1016,19 @@ struct PD_INFER_DECL AnalysisConfig {
void SetSkipLoadParams(bool value) { skip_load_params_ = value; } void SetSkipLoadParams(bool value) { skip_load_params_ = value; }
///
/// \brief Enable use cinn compiler optimization.
///
void Exp_EnableCINNCompiler();
///
/// \brief A boolean state telling whether the CINN compiler optimization is
/// turned on.
///
/// \return bool Whether the CINN compiler optimization is turned on.
///
bool cinn_compiler_enabled() const;
protected: protected:
// Update the config. // Update the config.
void Update(); void Update();
...@@ -1143,6 +1156,9 @@ struct PD_INFER_DECL AnalysisConfig { ...@@ -1143,6 +1156,9 @@ struct PD_INFER_DECL AnalysisConfig {
Precision lite_precision_mode_; Precision lite_precision_mode_;
bool lite_zero_copy_; bool lite_zero_copy_;
// CINN compiler related.
bool use_cinn_compiler_{false};
// XPU related. // XPU related.
bool use_xpu_{false}; bool use_xpu_{false};
int xpu_device_id_{0}; int xpu_device_id_{0};
......
...@@ -204,6 +204,13 @@ const std::vector<std::string> kTrtLowerPrecisionPasses{ ...@@ -204,6 +204,13 @@ const std::vector<std::string> kTrtLowerPrecisionPasses{
"tensorrt_subgraph_pass", "tensorrt_subgraph_pass",
}; };
const std::vector<std::string> kCINNCompilerPasses{
"gpu_cpu_map_matmul_v2_to_mul_pass",
"gpu_cpu_map_matmul_v2_to_matmul_pass",
"gpu_cpu_map_matmul_to_mul_pass",
"build_cinn_pass",
};
GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
passes_.assign({ passes_.assign({
// "identity_scale_op_clean_pass", // // "identity_scale_op_clean_pass", //
......
...@@ -349,6 +349,9 @@ PD_INFER_DECL extern const std::vector<std::string> kDlnneSubgraphPasses; ...@@ -349,6 +349,9 @@ PD_INFER_DECL extern const std::vector<std::string> kDlnneSubgraphPasses;
/// \brief List of lite subgraph passes. /// \brief List of lite subgraph passes.
PD_INFER_DECL extern const std::vector<std::string> kLiteSubgraphPasses; PD_INFER_DECL extern const std::vector<std::string> kLiteSubgraphPasses;
/// \brief List of cinn compiler passes.
PD_INFER_DECL extern const std::vector<std::string> kCINNCompilerPasses;
/// \brief TODO(inference): Most of the existing pass fusion operators do not /// \brief TODO(inference): Most of the existing pass fusion operators do not
/// support fp16/bf16 precision, temporarily use low precision pass to prevent /// support fp16/bf16 precision, temporarily use low precision pass to prevent
/// running errors. After fusion operator supports low precision, delete this. /// running errors. After fusion operator supports low precision, delete this.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册