未验证 提交 4048cfa9 编写于 作者: Z Zhaolong Xing 提交者: GitHub

Merge pull request #15048 from NHZlX/add_affine_channel_fuse

Add conv+ affine channel fuse pass 
...@@ -45,6 +45,7 @@ pass_library(is_test_pass base) ...@@ -45,6 +45,7 @@ pass_library(is_test_pass base)
pass_library(conv_elementwise_add_act_fuse_pass inference) pass_library(conv_elementwise_add_act_fuse_pass inference)
pass_library(conv_elementwise_add2_act_fuse_pass inference) pass_library(conv_elementwise_add2_act_fuse_pass inference)
pass_library(conv_elementwise_add_fuse_pass inference) pass_library(conv_elementwise_add_fuse_pass inference)
pass_library(conv_affine_channel_fuse_pass inference)
if(WITH_MKLDNN) if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base) pass_library(mkldnn_placement_pass base)
pass_library(depthwise_conv_mkldnn_pass base) pass_library(depthwise_conv_mkldnn_pass base)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.h"
#include <functional>
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
#define GET_CONV_BN_NODES(pattern_name) \
/* OPERATORS */ \
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(affine_channel, affine_channel, pattern_name); \
/* CONV inputs */ \
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, pattern_name); \
/* CONV outputs */ \
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, pattern_name); \
/* Affine Channel inputs */ \
GET_IR_NODE_FROM_SUBGRAPH(ac_scale, ac_scale, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(ac_bias, ac_bias, pattern_name); \
/* Affine channel outputs */ \
GET_IR_NODE_FROM_SUBGRAPH(ac_out, ac_out, pattern_name); /* Out */
void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight,
const ir::Node& ac_scale,
const LoDTensor& ac_bias_tensor,
LoDTensor* eltwise_y_in_tensor) {
using EigenVectorArrayMap =
Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>;
using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<float, Eigen::Dynamic, 1>>;
using EigenMatrixArrayMap = Eigen::Map<
Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
// Re-compute bias of conv2d from AffineChannel
PADDLE_ENFORCE_EQ(eltwise_y_in_tensor->dims(), ac_bias_tensor.dims());
auto* scale_tensor = scope->FindVar(ac_scale.Name())->GetMutable<LoDTensor>();
ConstEigenVectorArrayMap scale_array(scale_tensor->data<float>(),
scale_tensor->numel(), 1);
ConstEigenVectorArrayMap ac_bias_array(ac_bias_tensor.data<float>(),
ac_bias_tensor.numel(), 1);
EigenVectorArrayMap eltwise_y_in_array(
eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
eltwise_y_in_tensor->numel(), 1);
eltwise_y_in_array = (eltwise_y_in_array * scale_array) + ac_bias_array;
// Re-compute weight of conv2d from AffineChannel
auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
auto weights_shape = weights->dims();
auto weights_shape_2d = flatten_to_2d(weights_shape, 1);
EigenMatrixArrayMap weights_array_2d(
weights->mutable_data<float>(platform::CPUPlace()), weights_shape_2d[0],
weights_shape_2d[1]);
weights_array_2d.colwise() *= scale_array;
}
std::unique_ptr<ir::Graph> ConvAffineChannelFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
GraphPatternDetector gpd;
auto* conv_input =
gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->AsInput()
->assert_is_op_input("conv2d", "Input");
patterns::ConvAffineChannel conv_ac_pattern(gpd.mutable_pattern(),
name_scope_);
conv_ac_pattern(conv_input, false /*with_eltwise_add*/);
int found_conv_ac_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle ConvAffineChannel fuse";
GET_CONV_BN_NODES(conv_ac_pattern);
// check if fuse can be done and if MKL-DNN should be used
FuseOptions fuse_option = FindFuseOption(*conv, *affine_channel);
if (fuse_option == DO_NOT_FUSE) {
VLOG(3) << "do not perform conv+affinechannel fuse";
return;
}
// Create eltwise_y (conv bias) variable
VarDesc eltwise_y_in_desc(
patterns::PDNodeName(name_scope_, "eltwise_y_in"));
eltwise_y_in_desc.SetPersistable(true);
auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
auto* eltwise_y_in_tensor =
scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>();
// Get affine_channel bias
auto* ac_bias_tensor =
scope->FindVar(ac_bias->Name())->GetMutable<LoDTensor>();
// Initialize eltwise_y
eltwise_y_in_tensor->Resize(ac_bias_tensor->dims());
std::fill_n(eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
eltwise_y_in_tensor->numel(), 0.0f);
// update weights and biases
recompute_bias_and_weights(scope, conv_weight, *ac_scale, *ac_bias_tensor,
eltwise_y_in_tensor);
// create an elementwise add node.
OpDesc desc;
desc.SetInput("X", std::vector<std::string>({conv_out->Name()}));
desc.SetInput("Y", std::vector<std::string>({eltwise_y_in_node->Name()}));
desc.SetOutput("Out", std::vector<std::string>({ac_out->Name()}));
desc.SetType("elementwise_add");
desc.SetAttr("axis", 1);
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph.get(), {ac_scale, ac_bias, affine_channel});
IR_NODE_LINK_TO(conv_out, eltwise_op);
IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
IR_NODE_LINK_TO(eltwise_op, ac_out);
found_conv_ac_count++;
};
gpd(graph.get(), handler);
AddStatis(found_conv_ac_count);
return graph;
}
std::unique_ptr<ir::Graph> ConvEltwiseAddAffineChannelFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
GraphPatternDetector gpd;
auto* conv_input =
gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->AsInput()
->assert_is_op_input("conv2d", "Input");
patterns::ConvAffineChannel conv_ac_pattern(gpd.mutable_pattern(),
name_scope_);
conv_ac_pattern(conv_input, true /*with_eltwise_add*/);
int found_conv_ac_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle ConvBN fuse";
GET_CONV_BN_NODES(conv_ac_pattern);
// OPERATORS
GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_ac_pattern);
// BIAS inputs
GET_IR_NODE_FROM_SUBGRAPH(eltwise_y_in, eltwise_y_in, conv_ac_pattern);
// BIAS outputs
GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_ac_pattern);
// Get eltwise_y (conv bias) variable
auto* eltwise_y_in_tensor =
scope->FindVar(eltwise_y_in->Name())->GetMutable<LoDTensor>();
// Get batch norm bias
auto* ac_bias_tensor =
scope->FindVar(ac_bias->Name())->GetMutable<LoDTensor>();
recompute_bias_and_weights(scope, conv_weight, *ac_scale, *ac_bias_tensor,
eltwise_y_in_tensor);
// Update the elementwise_add node
eltwise->Op()->SetAttr("axis", 1);
eltwise->Op()->SetOutput("Out", std::vector<std::string>({ac_out->Name()}));
GraphSafeRemoveNodes(graph.get(),
{ac_scale, ac_bias, affine_channel, eltwise_out});
IR_NODE_LINK_TO(eltwise, ac_out);
found_conv_ac_count++;
};
gpd(graph.get(), handler);
AddStatis(found_conv_ac_count);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_affine_channel_fuse_pass,
paddle::framework::ir::ConvAffineChannelFusePass);
REGISTER_PASS(conv_eltwiseadd_affine_channel_fuse_pass,
paddle::framework::ir::ConvEltwiseAddAffineChannelFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the Conv and ConvAffineChannel.
*/
class ConvAffineChannelFusePass : public FusePassBase {
public:
virtual ~ConvAffineChannelFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"conv_affine_channel_fuse"};
};
class ConvEltwiseAddAffineChannelFusePass : public FusePassBase {
public:
virtual ~ConvEltwiseAddAffineChannelFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"conv_eltwiseadd_affine_channel_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -1234,6 +1234,78 @@ PDNode *patterns::ConvElementwiseadd::operator()(PDNode *conv_in) { ...@@ -1234,6 +1234,78 @@ PDNode *patterns::ConvElementwiseadd::operator()(PDNode *conv_in) {
return elementwise_add_out; return elementwise_add_out;
} }
PDNode *patterns::ConvAffineChannel::operator()(
paddle::framework::ir::PDNode *conv_input, bool with_eltwise_add) {
// Create Operators
conv_input->assert_is_op_input("conv2d", "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
PDNode *eltwise_op = nullptr;
if (with_eltwise_add) {
eltwise_op =
pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add");
}
auto *affine_channel_op =
pattern->NewNode(affine_channel_repr())->assert_is_op("affine_channel");
// Create variables
// Conv Filter
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("conv2d", "Filter");
auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("conv2d");
PDNode *eltwise_y_in_var = nullptr;
PDNode *eltwise_out_var = nullptr;
if (with_eltwise_add) {
// Conv output as Bias input
conv_out_var->assert_is_op_input("elementwise_add", "X");
// Bias
eltwise_y_in_var = pattern->NewNode(eltwise_y_in_repr())
->assert_is_op_input("elementwise_add", "Y")
->AsInput();
eltwise_out_var = pattern->NewNode(eltwise_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("elementwise_add");
} else {
// Conv output as AffineChannel input
conv_out_var->assert_is_op_input("affine_channel", "X");
}
// AC Scale
auto *ac_scale_var = pattern->NewNode(ac_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("affine_channel", "Scale");
// AC Bias
auto *ac_bias_var = pattern->NewNode(ac_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("affine_channel", "Bias");
// AC output
auto *ac_out_var = pattern->NewNode(ac_out_repr())
->AsOutput()
->assert_is_op_output("affine_channel");
conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
if (with_eltwise_add) {
eltwise_op->LinksFrom({conv_out_var, eltwise_y_in_var})
.LinksTo({eltwise_out_var});
affine_channel_op->LinksFrom({eltwise_out_var, ac_scale_var, ac_bias_var})
.LinksTo({ac_out_var});
} else {
affine_channel_op->LinksFrom({conv_out_var, ac_scale_var, ac_bias_var})
.LinksTo({ac_out_var});
}
return ac_out_var;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -734,6 +734,38 @@ struct ConvElementwiseadd : public PatternBase { ...@@ -734,6 +734,38 @@ struct ConvElementwiseadd : public PatternBase {
PATTERN_DECL_NODE(elementwise_add_out); PATTERN_DECL_NODE(elementwise_add_out);
}; };
// Conv with affine_channel
// op: conv + (elementwise_add +) affine_channel
// named nodes:
// conv_weight, conv_out, conv,
// ac_x, ac_scale, ac_bias
// affine_channel, ac_out
struct ConvAffineChannel : public PatternBase {
ConvAffineChannel(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_affine_channel") {}
PDNode* operator()(PDNode* conv_input, bool with_eltwise_add);
// declare operator node's name
PATTERN_DECL_NODE(conv);
PATTERN_DECL_NODE(affine_channel);
PATTERN_DECL_NODE(eltwise); // ELEMENTWISE_ADD
// CONV inputs
PATTERN_DECL_NODE(conv_weight); // Filter
// CONV outputs
PATTERN_DECL_NODE(conv_out); // tmp
// ELTWISE inputs
PATTERN_DECL_NODE(eltwise_y_in);
// ELTWISE outputs
PATTERN_DECL_NODE(eltwise_out); // tmp
// AC(Affine_Channel) inputs
PATTERN_DECL_NODE(ac_scale);
PATTERN_DECL_NODE(ac_bias);
// AC outputs
PATTERN_DECL_NODE(ac_out); // Out
};
} // namespace patterns } // namespace patterns
// Link two ir::Nodes from each other. // Link two ir::Nodes from each other.
......
...@@ -127,6 +127,7 @@ struct Argument { ...@@ -127,6 +127,7 @@ struct Argument {
std::function<bool(const framework::ir::Node*)>); std::function<bool(const framework::ir::Node*)>);
DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int); DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int);
DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int); DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int);
DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int);
// The program transformed by IR analysis phase. // The program transformed by IR analysis phase.
DECL_ARGUMENT_UNIQUE_FIELD(ir_analyzed_program, IrAnalyzedProgram, DECL_ARGUMENT_UNIQUE_FIELD(ir_analyzed_program, IrAnalyzedProgram,
......
...@@ -75,6 +75,8 @@ void IRPassManager::CreatePasses(Argument *argument, ...@@ -75,6 +75,8 @@ void IRPassManager::CreatePasses(Argument *argument,
argument->tensorrt_node_teller_ptr()); argument->tensorrt_node_teller_ptr());
pass->Set("workspace_size", new int(argument->tensorrt_workspace_size())); pass->Set("workspace_size", new int(argument->tensorrt_workspace_size()));
pass->Set("max_batch_size", new int(argument->tensorrt_max_batch_size())); pass->Set("max_batch_size", new int(argument->tensorrt_max_batch_size()));
pass->Set("min_subgraph_size",
new int(argument->tensorrt_min_subgraph_size()));
} }
// graph_ = pass->Apply(std::move(graph_)); // graph_ = pass->Apply(std::move(graph_));
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h" #include <algorithm>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h" #include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -36,7 +38,8 @@ std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl( ...@@ -36,7 +38,8 @@ std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl(
auto teller = auto teller =
Get<SubgraphDetector::NodeInsideSubgraphTeller>("tensorrt_node_teller"); Get<SubgraphDetector::NodeInsideSubgraphTeller>("tensorrt_node_teller");
SubGraphFuser fuser(graph.get(), teller, 2 /*min subgraph size*/); SubGraphFuser fuser(graph.get(), teller,
Get<int>("min_subgraph_size") /*min subgraph size*/);
fuser(); fuser();
for (auto *node : graph->Nodes()) { for (auto *node : graph->Nodes()) {
...@@ -197,10 +200,26 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node, ...@@ -197,10 +200,26 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
std::vector<std::string> ExtractParameters( std::vector<std::string> ExtractParameters(
const std::unordered_set<Node *> &nodes) { const std::unordered_set<Node *> &nodes) {
// We can judge whether a variable is a parameter by
// its presistable property, but sometimes the presistable
// of the feed op output is true, so we have to identify it.
std::vector<std::string> feed_outputs;
for (const auto &node : nodes) {
if (!node->IsOp()) continue;
std::string op_type = node->Op()->Type();
if (op_type == "feed") {
std::vector<std::string> output_names = node->Op()->OutputArgumentNames();
std::copy(output_names.begin(), output_names.end(),
std::back_inserter(feed_outputs));
}
}
std::vector<std::string> parameters; std::vector<std::string> parameters;
for (const auto &node : nodes) { for (const auto &node : nodes) {
if (!node->IsVar()) continue; if (!node->IsVar()) continue;
if (node->Var()->Persistable()) { if (node->Var()->Persistable() &&
std::find(feed_outputs.begin(), feed_outputs.end(), node->Name()) ==
feed_outputs.end()) {
parameters.push_back(node->Name()); parameters.push_back(node->Name());
} }
} }
...@@ -215,4 +234,5 @@ REGISTER_PASS(tensorrt_subgraph_pass, ...@@ -215,4 +234,5 @@ REGISTER_PASS(tensorrt_subgraph_pass,
paddle::inference::analysis::TensorRtSubgraphPass) paddle::inference::analysis::TensorRtSubgraphPass)
.RequirePassAttr("tensorrt_node_teller") .RequirePassAttr("tensorrt_node_teller")
.RequirePassAttr("max_batch_size") .RequirePassAttr("max_batch_size")
.RequirePassAttr("workspace_size"); .RequirePassAttr("workspace_size")
.RequirePassAttr("min_subgraph_size");
...@@ -57,6 +57,7 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) { ...@@ -57,6 +57,7 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) {
use_tensorrt_ = other.use_tensorrt_; use_tensorrt_ = other.use_tensorrt_;
tensorrt_max_batchsize_ = other.tensorrt_max_batchsize_; tensorrt_max_batchsize_ = other.tensorrt_max_batchsize_;
tensorrt_workspace_size_ = other.tensorrt_workspace_size_; tensorrt_workspace_size_ = other.tensorrt_workspace_size_;
tensorrt_min_subgraph_size_ = other.tensorrt_min_subgraph_size_;
model_from_memory_ = other.model_from_memory_; model_from_memory_ = other.model_from_memory_;
if (use_gpu) { if (use_gpu) {
...@@ -89,6 +90,7 @@ contrib::AnalysisConfig::AnalysisConfig(contrib::AnalysisConfig &&other) { ...@@ -89,6 +90,7 @@ contrib::AnalysisConfig::AnalysisConfig(contrib::AnalysisConfig &&other) {
use_tensorrt_ = other.use_tensorrt_; use_tensorrt_ = other.use_tensorrt_;
tensorrt_max_batchsize_ = other.tensorrt_max_batchsize_; tensorrt_max_batchsize_ = other.tensorrt_max_batchsize_;
tensorrt_workspace_size_ = other.tensorrt_workspace_size_; tensorrt_workspace_size_ = other.tensorrt_workspace_size_;
tensorrt_min_subgraph_size_ = other.tensorrt_min_subgraph_size_;
model_from_memory_ = other.model_from_memory_; model_from_memory_ = other.model_from_memory_;
pass_builder_ = std::move(other.pass_builder_); pass_builder_ = std::move(other.pass_builder_);
...@@ -105,12 +107,14 @@ void contrib::AnalysisConfig::EnableMKLDNN() { ...@@ -105,12 +107,14 @@ void contrib::AnalysisConfig::EnableMKLDNN() {
} }
void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size, void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size,
int max_batch_size) { int max_batch_size,
int min_subgraph_size) {
use_tensorrt_ = true; use_tensorrt_ = true;
tensorrt_workspace_size_ = workspace_size; tensorrt_workspace_size_ = workspace_size;
tensorrt_max_batchsize_ = max_batch_size; tensorrt_max_batchsize_ = max_batch_size;
// Append after the infer_clean pass. tensorrt_min_subgraph_size_ = min_subgraph_size;
pass_builder()->InsertPass(1, "tensorrt_subgraph_pass"); // Append after the conv+affine_channel fuse pass.
pass_builder()->InsertPass(3, "tensorrt_subgraph_pass");
} }
void contrib::AnalysisConfig::SetModelBuffer(const char *prog_buffer, void contrib::AnalysisConfig::SetModelBuffer(const char *prog_buffer,
......
...@@ -328,6 +328,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() { ...@@ -328,6 +328,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
argument_.SetUseTensorRT(true); argument_.SetUseTensorRT(true);
argument_.SetTensorRtWorkspaceSize(config_.tensorrt_workspace_size_); argument_.SetTensorRtWorkspaceSize(config_.tensorrt_workspace_size_);
argument_.SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_); argument_.SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_);
argument_.SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_);
} }
if (config_.use_mkldnn_) { if (config_.use_mkldnn_) {
......
...@@ -49,7 +49,7 @@ struct AnalysisConfig : public NativeConfig { ...@@ -49,7 +49,7 @@ struct AnalysisConfig : public NativeConfig {
bool use_feed_fetch_ops{true}; bool use_feed_fetch_ops{true};
void EnableTensorRtEngine(int workspace_size = 1 << 20, void EnableTensorRtEngine(int workspace_size = 1 << 20,
int max_batch_size = 1); int max_batch_size = 1, int min_subgraph_size = 3);
bool use_tensorrt() const { return use_tensorrt_; } bool use_tensorrt() const { return use_tensorrt_; }
void EnableMKLDNN(); void EnableMKLDNN();
...@@ -69,8 +69,19 @@ struct AnalysisConfig : public NativeConfig { ...@@ -69,8 +69,19 @@ struct AnalysisConfig : public NativeConfig {
bool use_tensorrt_{false}; bool use_tensorrt_{false};
bool use_mkldnn_{false}; bool use_mkldnn_{false};
std::unordered_set<std::string> mkldnn_enabled_op_types_; std::unordered_set<std::string> mkldnn_enabled_op_types_;
// For workspace_size, refer it from here:
// https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#troubleshooting
int tensorrt_workspace_size_; int tensorrt_workspace_size_;
// While TensorRT allows an engine optimized for a given max batch size
// to run at any smaller size, the performance for those smaller
// sizes may not be as well-optimized. Therefore, Max batch is best
// equivalent to the runtime batch size.
int tensorrt_max_batchsize_; int tensorrt_max_batchsize_;
// We transform the Ops that can be converted into TRT layer in the model,
// and aggregate these Ops into subgraphs for TRT execution.
// We set this variable to control the minimum number of nodes in the
// subgraph, 3 as default value.
int tensorrt_min_subgraph_size_{3};
std::unique_ptr<PassStrategy> pass_builder_; std::unique_ptr<PassStrategy> pass_builder_;
bool model_from_memory_{false}; bool model_from_memory_{false};
}; };
......
...@@ -119,6 +119,8 @@ class GpuPassStrategy : public PassStrategy { ...@@ -119,6 +119,8 @@ class GpuPassStrategy : public PassStrategy {
GpuPassStrategy() : PassStrategy({}) { GpuPassStrategy() : PassStrategy({}) {
passes_.assign({ passes_.assign({
"infer_clean_graph_pass", // "infer_clean_graph_pass", //
"conv_affine_channel_fuse_pass", //
"conv_eltwiseadd_affine_channel_fuse_pass", //
"conv_bn_fuse_pass", // "conv_bn_fuse_pass", //
"conv_elementwise_add_act_fuse_pass", // "conv_elementwise_add_act_fuse_pass", //
"conv_elementwise_add2_act_fuse_pass", // "conv_elementwise_add2_act_fuse_pass", //
......
...@@ -161,9 +161,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -161,9 +161,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit, PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
"workspace_size to be allocated exceeds the limit"); "workspace_size to be allocated exceeds the limit");
if ((activation == "identity") && if ((activation == "identity") && (!residual)) {
(algo != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) &&
(!residual)) {
// Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is // Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is
// enabled with CUDNN_ACTIVATION_IDENTITY in cuDNN lib. // enabled with CUDNN_ACTIVATION_IDENTITY in cuDNN lib.
// But test in some case, the speed is slower, change to use // But test in some case, the speed is slower, change to use
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册