未验证 提交 97f43a8e 编写于 作者: Z zhoutianzi666 提交者: GitHub

[Paddle-TRT] constant-folding (#45494)

add constant folding pass, for some model,it will get less latency;
上级 9dad4f79
...@@ -147,6 +147,7 @@ pass_library(delete_dropout_op_pass inference) ...@@ -147,6 +147,7 @@ pass_library(delete_dropout_op_pass inference)
pass_library(delete_c_identity_op_pass inference) pass_library(delete_c_identity_op_pass inference)
pass_library(preln_residual_bias_fuse_pass inference) pass_library(preln_residual_bias_fuse_pass inference)
pass_library(delete_fill_constant_op_pass inference) pass_library(delete_fill_constant_op_pass inference)
pass_library(constant_folding_pass inference)
pass_library(simplify_with_basic_ops_pass base) pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base) pass_library(fc_elementwise_layernorm_fuse_pass base)
pass_library(skip_layernorm_fuse_pass base) pass_library(skip_layernorm_fuse_pass base)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/constant_folding_pass.h"
#include <string>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/framework/convert_utils.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
} // namespace ir
} // namespace framework
} // namespace paddle
/*
* When a op's inputs and outputs is determined before feeding data to the
* model, we can remove this op from the model. This ConstantFolding pass can
* remove all these like ops.
*
*/
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct ConstantFolding : public PatternBase {
ConstantFolding(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "constant_folding_pass") {}
};
} // namespace patterns
ConstantFoldingPass::ConstantFoldingPass() {}
void ConstantFoldingPass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init("constant_folding", graph);
auto *scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::Fatal(
"scope must not be null when applying constant floding."));
// Now, I don't want to fold fill_constant op in Paddle-TRT
std::vector<std::string> blacklist{"fill_constant", "feed"};
auto op_node_sorted = framework::ir::TopologyVarientSort(
*graph, static_cast<framework::ir::SortKind>(0));
for (auto *op_node : op_node_sorted) {
if (!op_node->IsOp()) continue;
if (std::find(blacklist.begin(), blacklist.end(), op_node->Name()) !=
blacklist.end())
continue;
bool input_persis = true;
// map is used to record how many time a name string occures in the whole
// graph's nodes
std::map<std::string, int> map;
for (auto in_node : op_node->inputs) {
map[in_node->Name()] = 0;
if (!in_node->Var()->Persistable()) {
input_persis = false;
}
}
for (auto out_node : op_node->outputs) {
map[out_node->Name()] = 0;
}
// Forbid other node in graph having the same name with nodes in map
for (auto iter : map) {
for (auto node : graph->Nodes()) {
if (node->IsVar() && node->Name() == iter.first) {
map[node->Name()]++;
if (map[node->Name()] > 1) {
input_persis = false;
}
}
}
}
framework::Scope *local_scope = new framework::Scope();
std::unordered_set<const paddle::framework::ir::Node *> remove_nodes;
std::unique_ptr<OperatorBase> op;
if (input_persis) {
for (auto in_node : op_node->inputs) {
local_scope->Var(in_node->Var()->Name());
local_scope->FindVar(in_node->Var()->Name())->GetMutable<LoDTensor>();
// This persistable input node is exclusive, and can be removed
if (in_node->outputs.size() == 1L) remove_nodes.emplace(in_node);
auto in_shape = in_node->Var()->GetShape();
auto *global_persis_x_tensor =
scope->FindVar(in_node->Name())->GetMutable<LoDTensor>();
auto *local_x_tensor =
local_scope->FindVar(in_node->Name())->GetMutable<LoDTensor>();
local_x_tensor->Resize(global_persis_x_tensor->dims());
*local_x_tensor = *global_persis_x_tensor;
}
op = paddle::framework::OpRegistry::CreateOp(*op_node->Op());
remove_nodes.emplace(op_node);
for (auto out_node : op_node->outputs) {
local_scope->Var(out_node->Var()->Name());
local_scope->FindVar(out_node->Var()->Name())->GetMutable<LoDTensor>();
// useless out_node can be removed, not need set it persistable !
if (out_node->outputs.size() == 0L) remove_nodes.emplace(out_node);
}
op->Run(*local_scope, platform::CPUPlace());
for (auto out_node : op_node->outputs) {
// this out_node is useless, do not set it persistable
if (out_node->outputs.size() == 0L) continue;
auto out_desc = out_node->Var();
auto out_name = out_desc->Name();
auto *local_out_tensor =
local_scope->FindVar(out_name)->GetMutable<LoDTensor>();
std::vector<int64_t> out_shape;
for (int64_t i = 0; i < local_out_tensor->dims().size(); i++) {
out_shape.push_back(local_out_tensor->dims()[i]);
}
out_desc->SetShape(out_shape);
out_desc->SetPersistable(true);
auto *global_out_tensor = scope->Var(out_name)->GetMutable<LoDTensor>();
*global_out_tensor = *local_out_tensor;
}
GraphSafeRemoveNodes(graph, remove_nodes);
}
delete local_scope;
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(constant_folding_pass,
paddle::framework::ir::ConstantFoldingPass);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
class ConstantFoldingPass : public FusePassBase {
public:
ConstantFoldingPass();
virtual ~ConstantFoldingPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -121,8 +121,9 @@ const std::vector<std::string> kTRTSubgraphPasses({ ...@@ -121,8 +121,9 @@ const std::vector<std::string> kTRTSubgraphPasses({
// "yolo_box_fuse_pass", // // "yolo_box_fuse_pass", //
"dense_fc_to_sparse_pass", // "dense_fc_to_sparse_pass", //
"dense_multihead_matmul_to_sparse_pass", // "dense_multihead_matmul_to_sparse_pass", //
"tensorrt_subgraph_pass", // "constant_folding_pass",
"conv_bn_fuse_pass", // "tensorrt_subgraph_pass", //
"conv_bn_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
// guaranteed at least v7 // guaranteed at least v7
// cudnn8.0 has memory leak problem in conv + eltwise + act, so we // cudnn8.0 has memory leak problem in conv + eltwise + act, so we
...@@ -213,6 +214,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { ...@@ -213,6 +214,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_elementwise_add_fuse_pass", // "conv_elementwise_add_fuse_pass", //
#endif // #endif //
"transpose_flatten_concat_fuse_pass", // "transpose_flatten_concat_fuse_pass", //
"constant_folding_pass",
// following pass should be located in the last, since it will // following pass should be located in the last, since it will
// work on all fused ops. // work on all fused ops.
"runtime_context_cache_pass" "runtime_context_cache_pass"
...@@ -276,6 +278,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { ...@@ -276,6 +278,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
"conv_transpose_bn_fuse_pass", // "conv_transpose_bn_fuse_pass", //
"conv_transpose_eltwiseadd_bn_fuse_pass", // "conv_transpose_eltwiseadd_bn_fuse_pass", //
"is_test_pass", // "is_test_pass", //
"constant_folding_pass",
// following pass should be located in the last, since // following pass should be located in the last, since
// it will work on all fused ops. // it will work on all fused ops.
"runtime_context_cache_pass"}); "runtime_context_cache_pass"});
......
...@@ -169,9 +169,16 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, ...@@ -169,9 +169,16 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots,
input_slots->push_back(std::move(response_mask_tensor)); input_slots->push_back(std::move(response_mask_tensor));
} }
/*
* this model is unreasonable, it set a output tensor persistable, so
* ridiculous! so I disable constant_folding_pass
*/
void SetConfig(AnalysisConfig *cfg) { void SetConfig(AnalysisConfig *cfg) {
cfg->SetModel(FLAGS_infer_model + "/__model__", FLAGS_infer_model + "/param"); cfg->SetModel(FLAGS_infer_model + "/__model__", FLAGS_infer_model + "/param");
cfg->SwitchSpecifyInputNames(); cfg->SwitchSpecifyInputNames();
auto pass_builder = cfg->pass_builder();
pass_builder->DeletePass("constant_folding_pass");
cfg->SwitchIrOptim(true); cfg->SwitchIrOptim(true);
} }
......
...@@ -17,6 +17,11 @@ ...@@ -17,6 +17,11 @@
namespace paddle { namespace paddle {
namespace inference { namespace inference {
/*
* this model is unreasonable, it set a middle-tensor persistable, so
* ridiculous! so I disable constant_folding_pass
*/
using paddle::PaddleTensor; using paddle::PaddleTensor;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
...@@ -25,6 +30,8 @@ void SetInt8Config(AnalysisConfig *cfg, ...@@ -25,6 +30,8 @@ void SetInt8Config(AnalysisConfig *cfg,
cfg->SetModel(FLAGS_infer_model); cfg->SetModel(FLAGS_infer_model);
cfg->EnableMKLDNN(); cfg->EnableMKLDNN();
cfg->EnableMkldnnQuantizer(); cfg->EnableMkldnnQuantizer();
auto pass_builder = cfg->pass_builder();
pass_builder->DeletePass("constant_folding_pass");
auto warmup_data = std::make_shared<std::vector<PaddleTensor>>(data); auto warmup_data = std::make_shared<std::vector<PaddleTensor>>(data);
cfg->mkldnn_quantizer_config()->SetWarmupData(warmup_data); cfg->mkldnn_quantizer_config()->SetWarmupData(warmup_data);
cfg->mkldnn_quantizer_config()->SetWarmupBatchSize(FLAGS_batch_size); cfg->mkldnn_quantizer_config()->SetWarmupBatchSize(FLAGS_batch_size);
......
...@@ -17,13 +17,19 @@ ...@@ -17,13 +17,19 @@
namespace paddle { namespace paddle {
namespace inference { namespace inference {
/*
* this model is unreasonable, it set a middle-tensor persistable, so
* ridiculous! so I disable constant_folding_pass
*/
using paddle::PaddleTensor; using paddle::PaddleTensor;
void profile(bool use_mkldnn = false, bool use_gpu = false) { void profile(bool use_mkldnn = false, bool use_gpu = false) {
AnalysisConfig config; AnalysisConfig config;
SetConfig(&config, use_mkldnn, use_gpu); SetConfig(&config, use_mkldnn, use_gpu);
auto pass_builder = config.pass_builder();
pass_builder->DeletePass("constant_folding_pass");
std::vector<std::vector<PaddleTensor>> outputs; std::vector<std::vector<PaddleTensor>> outputs;
std::vector<std::vector<PaddleTensor>> inputs; std::vector<std::vector<PaddleTensor>> inputs;
LoadInputData(&inputs); LoadInputData(&inputs);
...@@ -48,6 +54,9 @@ TEST(Analyzer_Ernie, fuse_statis) { ...@@ -48,6 +54,9 @@ TEST(Analyzer_Ernie, fuse_statis) {
AnalysisConfig cfg; AnalysisConfig cfg;
SetConfig(&cfg); SetConfig(&cfg);
auto pass_builder = cfg.pass_builder();
pass_builder->DeletePass("constant_folding_pass");
int num_ops; int num_ops;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg); auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
auto fuse_statis = GetFuseStatis( auto fuse_statis = GetFuseStatis(
...@@ -70,7 +79,8 @@ void compare(bool use_mkldnn = false) { ...@@ -70,7 +79,8 @@ void compare(bool use_mkldnn = false) {
AnalysisConfig cfg; AnalysisConfig cfg;
SetConfig(&cfg, use_mkldnn, false); SetConfig(&cfg, use_mkldnn, false);
auto pass_builder = cfg.pass_builder();
pass_builder->DeletePass("constant_folding_pass");
CompareNativeAndAnalysis( CompareNativeAndAnalysis(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), inputs); reinterpret_cast<const PaddlePredictor::Config *>(&cfg), inputs);
} }
...@@ -84,7 +94,8 @@ TEST(Analyzer_ernie, compare_mkldnn) { compare(true /* use_mkldnn */); } ...@@ -84,7 +94,8 @@ TEST(Analyzer_ernie, compare_mkldnn) { compare(true /* use_mkldnn */); }
TEST(Analyzer_Ernie, compare_determine) { TEST(Analyzer_Ernie, compare_determine) {
AnalysisConfig cfg; AnalysisConfig cfg;
SetConfig(&cfg); SetConfig(&cfg);
auto pass_builder = cfg.pass_builder();
pass_builder->DeletePass("constant_folding_pass");
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all); LoadInputData(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg), CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
...@@ -95,7 +106,8 @@ TEST(Analyzer_Ernie, compare_determine) { ...@@ -95,7 +106,8 @@ TEST(Analyzer_Ernie, compare_determine) {
TEST(Analyzer_Ernie, compare_results) { TEST(Analyzer_Ernie, compare_results) {
AnalysisConfig cfg; AnalysisConfig cfg;
SetConfig(&cfg); SetConfig(&cfg);
auto pass_builder = cfg.pass_builder();
pass_builder->DeletePass("constant_folding_pass");
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all); LoadInputData(&input_slots_all);
......
...@@ -31,10 +31,19 @@ int GetNumOps(const AnalysisConfig &cfg) { ...@@ -31,10 +31,19 @@ int GetNumOps(const AnalysisConfig &cfg) {
return num_ops; return num_ops;
} }
/*
* this model is unreasonable, it set a output tensor persistable, so
* ridiculous! so I disable constant_folding_pass
*/
TEST(Analyzer, save_model) { TEST(Analyzer, save_model) {
AnalysisConfig cfg; AnalysisConfig cfg;
SetConfig(&cfg); SetConfig(&cfg);
cfg.SetModel(FLAGS_infer_model + "/__model__", FLAGS_infer_model + "/param"); cfg.SetModel(FLAGS_infer_model + "/__model__", FLAGS_infer_model + "/param");
auto pass_builder = cfg.pass_builder();
pass_builder->DeletePass("constant_folding_pass");
// ensure the path being unique // ensure the path being unique
std::string optimModelPath = FLAGS_infer_model + "/only_for_save_model_test"; std::string optimModelPath = FLAGS_infer_model + "/only_for_save_model_test";
MKDIR(optimModelPath.c_str()); MKDIR(optimModelPath.c_str());
...@@ -49,6 +58,8 @@ TEST(Analyzer, save_model) { ...@@ -49,6 +58,8 @@ TEST(Analyzer, save_model) {
AnalysisConfig cfg3; AnalysisConfig cfg3;
SetConfig(&cfg3); SetConfig(&cfg3);
auto pass_builder3 = cfg3.pass_builder();
pass_builder3->DeletePass("constant_folding_pass");
cfg3.SetModel(optimModelPath + "/model", optimModelPath + "/params"); cfg3.SetModel(optimModelPath + "/model", optimModelPath + "/params");
int fused_num_ops = GetNumOps(cfg3); int fused_num_ops = GetNumOps(cfg3);
CHECK_LE(fused_num_ops, origin_num_ops); CHECK_LE(fused_num_ops, origin_num_ops);
......
...@@ -40,7 +40,7 @@ TEST(Analyzer_seq_pool1_fuse_statis, fuse_statis) { ...@@ -40,7 +40,7 @@ TEST(Analyzer_seq_pool1_fuse_statis, fuse_statis) {
EXPECT_EQ(fuse_statis.at("squared_mat_sub_fuse"), 0); EXPECT_EQ(fuse_statis.at("squared_mat_sub_fuse"), 0);
EXPECT_EQ(fuse_statis.at("repeated_fc_relu_fuse"), 2); EXPECT_EQ(fuse_statis.at("repeated_fc_relu_fuse"), 2);
LOG(INFO) << "num_ops: " << num_ops; LOG(INFO) << "num_ops: " << num_ops;
EXPECT_EQ(num_ops, 185); EXPECT_EQ(num_ops, 183);
} }
} // namespace seq_pool1_tester } // namespace seq_pool1_tester
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册