Unverified · Commit 1ea9971a · authored by JingZhuangzhuang · committed by GitHub

modify graph_pattern to thread_local (#43945)

Parent 26187c27
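What actually changes here: under `PADDLE_WITH_TENSORRT`, the counter map `dic_` inside `KeyCounter` (which numbers generated PDNode names per `repr`) becomes a `static thread_local` member with a new `CleanCounter()` helper, and `CreatePaddlePredictor` clears the counter before initializing a new predictor. The rest of the diff appears to be mechanical one-argument-per-line reflow. A minimal sketch of the `static thread_local` member pattern, reduced from the diff (the surrounding Paddle scaffolding is omitted):

    #include <string>
    #include <unordered_map>

    struct KeyCounter {
      static KeyCounter& Instance() {
        static KeyCounter x;
        return x;
      }

      // Each thread sees its own dic_, so the counters need no lock and two
      // threads building patterns concurrently cannot perturb each other.
      static int IncCounter(const std::string& key) { return dic_[key]++; }
      static void CleanCounter() { dic_.clear(); }

     private:
      static thread_local std::unordered_map<std::string, size_t> dic_;
    };

    // Defined in exactly one translation unit; `thread_local` is repeated
    // at the definition (see the first hunk of graph_pattern_detector.cc).
    thread_local std::unordered_map<std::string, size_t> KeyCounter::dic_;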
paddle/fluid/framework/ir/graph_pattern_detector.cc

@@ -28,10 +28,17 @@ using string::Style;
 size_t PDPattern::id_ = 0UL;
+#ifdef PADDLE_WITH_TENSORRT
+namespace patterns {
+thread_local std::unordered_map<std::string, size_t> KeyCounter::dic_;
+}
+#endif
 PDNode *PDPattern::NewNode(const std::string &name) {
   if (!name.empty()) {
     PADDLE_ENFORCE_EQ(
-        node_map_.count(name), 0UL,
+        node_map_.count(name),
+        0UL,
         platform::errors::PreconditionNotMet(
             "PDNode's name should be unique, get duplicate [%s]", name));
   }
@@ -45,7 +52,8 @@ PDNode *PDPattern::NewNode(const std::string &name) {
 PDNode *PDPattern::NewNode(PDNode::teller_t &&teller, const std::string &name) {
   if (!name.empty()) {
     PADDLE_ENFORCE_EQ(
-        node_map_.count(name), 0UL,
+        node_map_.count(name),
+        0UL,
         platform::errors::PreconditionNotMet(
             "PDNode's name should be unique, get duplicate [%s]", name));
   }
@@ -70,7 +78,9 @@ void PDPattern::AddEdge(PDNode *a, PDNode *b) {
       a, platform::errors::NotFound("PDNode %s is not found.", a->name()));
   PADDLE_ENFORCE_NOT_NULL(
       b, platform::errors::NotFound("PDNode %s is not found.", b->name()));
-  PADDLE_ENFORCE_NE(a, b, platform::errors::PermissionDenied(
-                              "Cannot connect the same node in the graph."));
+  PADDLE_ENFORCE_NE(a,
+                    b,
+                    platform::errors::PermissionDenied(
+                        "Cannot connect the same node in the graph."));
   edges_.emplace_back(a, b);
 }
@@ -128,7 +138,8 @@ void GraphPatternDetector::ValidateByNodeRole(
   subgraphs->erase(
       std::remove_if(
-          subgraphs->begin(), subgraphs->end(),
+          subgraphs->begin(),
+          subgraphs->end(),
           [](const GraphPatternDetector::subgraph_t &subgraph) -> bool {
             // Collect the inputs and outputs.
             std::set<Node *> ios;
@@ -310,7 +321,8 @@ void GraphPatternDetector::SortSubgraphs(
   }
   std::sort(
-      subgraphs->begin(), subgraphs->end(),
+      subgraphs->begin(),
+      subgraphs->end(),
       [](const GraphPatternDetector::subgraph_t &a,
          const GraphPatternDetector::subgraph_t &b) {
         for (auto &item : a) {
@@ -438,7 +450,8 @@ PDNode *PDNode::assert_is_persistable_var() {
 }
 PDNode *PDNode::assert_is_op_nth_input(const std::string &op_type,
-                                       const std::string &argument, int nth) {
+                                       const std::string &argument,
+                                       int nth) {
   assert_is_var();
   assert_is_op_input(op_type);
   asserts_.emplace_back([=](Node *x) {
@@ -453,7 +466,8 @@ PDNode *PDNode::assert_is_op_nth_input(const std::string &op_type,
 }
 PDNode *PDNode::assert_is_op_nth_output(const std::string &op_type,
-                                        const std::string &argument, int nth) {
+                                        const std::string &argument,
+                                        int nth) {
   assert_is_var();
   asserts_.emplace_back([=](Node *x) {
     for (auto *op : x->inputs) {
@@ -580,7 +594,8 @@ PDNode *PDNode::assert_is_ops(const std::unordered_set<std::string> &op_types) {
 PDNode *PDNode::assert_is_ops_nth_input(
     const std::unordered_set<std::string> &op_types,
-    const std::string &argument, int nth) {
+    const std::string &argument,
+    int nth) {
   assert_is_var();
   assert_is_ops_input(op_types);
   asserts_.emplace_back([=](Node *x) {
@@ -596,7 +611,8 @@ PDNode *PDNode::assert_is_ops_nth_input(
 PDNode *PDNode::assert_is_ops_nth_output(
     const std::unordered_set<std::string> &op_types,
-    const std::string &argument, int nth) {
+    const std::string &argument,
+    int nth) {
   assert_is_var();
   asserts_.emplace_back([=](Node *x) {
     for (auto *op : x->inputs) {
@@ -693,11 +709,13 @@ bool VarLinksToOp(Node *node, const std::string &op_type) {
 bool IsNthInput(Node *var, Node *op, const std::string &argument, size_t nth) {
   PADDLE_ENFORCE_EQ(
-      var->IsVar(), true,
+      var->IsVar(),
+      true,
       platform::errors::InvalidArgument(
          "First parameter of function IsNthInput must be Node::Var"));
   PADDLE_ENFORCE_EQ(
-      op->IsOp(), true,
+      op->IsOp(),
+      true,
       platform::errors::InvalidArgument(
          "Second parameter of function IsNthInput must be Node::Op"));
   if (!HasInput(op, argument) || op->Op()->Input(argument).size() <= nth)
@@ -707,7 +725,8 @@ bool IsNthInput(Node *var, Node *op, const std::string &argument, size_t nth) {
 bool HasInput(Node *op, const std::string &argument) {
   PADDLE_ENFORCE_EQ(
-      op->IsOp(), true,
+      op->IsOp(),
+      true,
       platform::errors::InvalidArgument(
          "First parameter of function HasInput must be Node::Op"));
   auto const &names = op->Op()->InputNames();
@@ -718,7 +737,8 @@ bool HasInput(Node *op, const std::string &argument) {
 bool HasOutput(Node *op, const std::string &argument) {
   PADDLE_ENFORCE_EQ(
-      op->IsOp(), true,
+      op->IsOp(),
+      true,
       platform::errors::InvalidArgument(
          "First parameter of function HasOuput must be Node::Op"));
   auto const &names = op->Op()->OutputNames();
@@ -729,11 +749,13 @@ bool HasOutput(Node *op, const std::string &argument) {
 bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) {
   PADDLE_ENFORCE_EQ(
-      var->IsVar(), true,
+      var->IsVar(),
+      true,
       platform::errors::InvalidArgument(
          "First parameter of function IsNthOutput must be Node::Var"));
   PADDLE_ENFORCE_EQ(
-      op->IsOp(), true,
+      op->IsOp(),
+      true,
       platform::errors::InvalidArgument(
          "Second parameter of function IsNthOutput must be Node::Op"));
   if (!HasOutput(op, argument) || op->Op()->Output(argument).size() <= nth)
@@ -875,22 +897,35 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
     eltwise_op->LinksFrom({conv_out_var, eltwise_y_in_var})
         .LinksTo({eltwise_out_var});
     batch_norm_op
-        ->LinksFrom({eltwise_out_var, bn_scale_var, bn_bias_var, bn_mean_var,
-                     bn_variance_var})
-        .LinksTo({bn_out_var, bn_mean_out_var, bn_variance_out_var,
-                  bn_saved_mean_var, bn_saved_variance_var});
+        ->LinksFrom({eltwise_out_var,
+                     bn_scale_var,
+                     bn_bias_var,
+                     bn_mean_var,
+                     bn_variance_var})
+        .LinksTo({bn_out_var,
+                  bn_mean_out_var,
+                  bn_variance_out_var,
+                  bn_saved_mean_var,
+                  bn_saved_variance_var});
   } else {
     batch_norm_op
-        ->LinksFrom({conv_out_var, bn_scale_var, bn_bias_var, bn_mean_var,
-                     bn_variance_var})
-        .LinksTo({bn_out_var, bn_mean_out_var, bn_variance_out_var,
-                  bn_saved_mean_var, bn_saved_variance_var});
+        ->LinksFrom({conv_out_var,
+                     bn_scale_var,
+                     bn_bias_var,
+                     bn_mean_var,
+                     bn_variance_var})
+        .LinksTo({bn_out_var,
+                  bn_mean_out_var,
+                  bn_variance_out_var,
+                  bn_saved_mean_var,
+                  bn_saved_variance_var});
   }
   return bn_out_var;
 }
 PDNode *patterns::ConvActivation::operator()(
-    paddle::framework::ir::PDNode *conv_input, std::string conv_type,
+    paddle::framework::ir::PDNode *conv_input,
+    std::string conv_type,
     std::string activation_type) {
   // Create Operators
   conv_input->assert_is_op_input(conv_type, "Input");
@@ -920,7 +955,8 @@ PDNode *patterns::ConvActivation::operator()(
 PDNode *patterns::ElementwiseActivation::operator()(
     paddle::framework::ir::PDNode *elementwise_a,
-    const std::string &elementwise_type, const std::string &activation_type) {
+    const std::string &elementwise_type,
+    const std::string &activation_type) {
   // Create Operators
   elementwise_a->assert_is_op_input(elementwise_type, "X");
   auto *elementwise_op =
@@ -995,7 +1031,8 @@ PDNode *patterns::SeqConvEltAddRelu::operator()(
 }
 PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
-                                 bool with_bias, bool with_relu) {
+                                 bool with_bias,
+                                 bool with_relu) {
   // Create shared nodes.
   x->assert_is_op_input("mul", "X");
   auto *mul = pattern->NewNode(mul_repr())->assert_is_op("mul");
@@ -1261,8 +1298,12 @@ PDNode *patterns::BatchNormAct::operator()(
   bn->LinksFrom(
         {bn_x_var, bn_scale_var, bn_bias_var, bn_variance_var, bn_mean_var})
-      .LinksTo({bn_mean_out_var, bn_variance_out_var, bn_saved_variance_var,
-                bn_saved_mean_var, bn_reserve_space, bn_out_var});
+      .LinksTo({bn_mean_out_var,
+                bn_variance_out_var,
+                bn_saved_variance_var,
+                bn_saved_mean_var,
+                bn_reserve_space,
+                bn_out_var});
   act->LinksFrom({bn_out_var}).LinksTo({act_out_var});
   return act_out_var;
@@ -1319,8 +1360,13 @@ PDNode *patterns::BatchNormActGrad::operator()(
       .LinksTo({d_intermediate_var});
   bn_grad
-      ->LinksFrom({bn_x_var, d_intermediate_var, bn_scale_var, bn_bias_var,
-                   bn_saved_mean_var, bn_saved_variance_var, bn_reserve_space})
+      ->LinksFrom({bn_x_var,
+                   d_intermediate_var,
+                   bn_scale_var,
+                   bn_bias_var,
+                   bn_saved_mean_var,
+                   bn_saved_variance_var,
+                   bn_reserve_space})
       .LinksTo({d_bn_x_var, d_bn_scale_var, d_bn_bias_var});
   return bn_grad;
@@ -1404,8 +1450,12 @@ PDNode *patterns::BatchNormAddAct::operator()(
       pattern->NewNode(act_out_repr())->assert_is_ops_output(act_types, "Out");
   bn->LinksFrom({bn_x_var, bn_scale_var, bn_bias_var})
-      .LinksTo({bn_mean_out_var, bn_variance_out_var, bn_saved_variance_var,
-                bn_saved_mean_var, bn_reserve_space, bn_out_var});
+      .LinksTo({bn_mean_out_var,
+                bn_variance_out_var,
+                bn_saved_variance_var,
+                bn_saved_mean_var,
+                bn_reserve_space,
+                bn_out_var});
   elewise_add->LinksFrom({elewise_add_in_var, bn_out_var})
       .LinksTo({elewise_add_out_var});
   act->LinksFrom({elewise_add_out_var}).LinksTo({act_out_var});
@@ -1484,8 +1534,13 @@ PDNode *patterns::BatchNormAddActGrad::operator()(
       .LinksTo({d_elewise_add_in_var, d_bn_out_var});
   bn_grad
-      ->LinksFrom({bn_x_var, d_bn_out_var, bn_scale_var, bn_bias_var,
-                   bn_saved_mean_var, bn_saved_variance_var, bn_reserve_space})
+      ->LinksFrom({bn_x_var,
+                   d_bn_out_var,
+                   bn_scale_var,
+                   bn_bias_var,
+                   bn_saved_mean_var,
+                   bn_saved_variance_var,
+                   bn_reserve_space})
       .LinksTo({d_bn_x_var, d_bn_scale_var, d_bn_bias_var});
   return bn_grad;
@@ -1558,7 +1613,8 @@ PDNode *patterns::ElewiseAddAct::operator()(
 PDNode *patterns::LinearAct::operator()(
     paddle::framework::ir::PDNode *linear_x_var,
-    const std::unordered_set<std::string> &act_types, bool with_grad_link,
+    const std::unordered_set<std::string> &act_types,
+    bool with_grad_link,
     bool is_act_grad_x_from_act) {
   auto *matmul_w_var =
       pattern->NewNode(matmul_w_repr())->assert_is_op_input("matmul_v2", "Y");
@@ -1621,7 +1677,8 @@ PDNode *patterns::LinearAct::operator()(
 PDNode *patterns::ElewiseAddMatmulAct::operator()(
     paddle::framework::ir::PDNode *dout_var,
     const std::unordered_set<std::string> &act_grad_types,
-    bool without_x_gradient, bool is_act_grad_x_from_act) {
+    bool without_x_gradient,
+    bool is_act_grad_x_from_act) {
   auto *ele_grad_bias_var =
       pattern->NewNode(ele_grad_bias_repr())
           ->assert_is_op_input("elementwise_add_grad", "Y");
@@ -2052,7 +2109,8 @@ PDNode *patterns::Pool::operator()() {
   return output_var;
 }
-PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var,
+PDNode *patterns::Elementwise::operator()(PDNode *x_var,
+                                          PDNode *y_var,
                                           const std::string elementwise_type) {
   auto elementwise_op =
       pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type);
@@ -2084,7 +2142,9 @@ PDNode *patterns::ElementwiseOp::operator()(
 }
 PDNode *patterns::ResidualElementwise::operator()(
-    PDNode *op_var, PDNode *residual_var, const std::string elementwise_type,
+    PDNode *op_var,
+    PDNode *residual_var,
+    const std::string elementwise_type,
     bool as_x) {
   auto elementwise_op =
       pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type);
@@ -3065,7 +3125,8 @@ void patterns::DeleteQuantDequantLinearOpPattern::operator()() {
 }
 PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
-    const std::string &op_name, bool with_reshape_xshape,
+    const std::string &op_name,
+    bool with_reshape_xshape,
     bool with_transpose_xshape) {
   auto reshape_op =
       pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2");
@@ -3098,8 +3159,7 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
   transpose_out->assert_is_only_output_of_op("transpose2");
   auto transpose_xshape =
-      with_transpose_xshape
-          ? pattern->NewNode(transpose_xshape_repr())
-                ->AsIntermediate()
-                ->assert_is_op_output("transpose2", "XShape")
-          : nullptr;
+      with_transpose_xshape ? pattern->NewNode(transpose_xshape_repr())
+                                  ->AsIntermediate()
+                                  ->assert_is_op_output("transpose2", "XShape")
+                            : nullptr;
......
paddle/fluid/framework/ir/graph_pattern_detector.h

@@ -122,10 +122,12 @@ struct PDNode {
   PDNode* assert_is_op_input(const std::string& op_type,
                              const std::string& argument);
   PDNode* assert_is_op_nth_input(const std::string& op_type,
-                                 const std::string& argument, int nth);
+                                 const std::string& argument,
+                                 int nth);
   PDNode* assert_is_not_op_input(const std::string& argument);
   PDNode* assert_is_op_nth_output(const std::string& op_type,
-                                  const std::string& argument, int nth);
+                                  const std::string& argument,
+                                  int nth);
   PDNode* assert_is_only_input_of_op(const std::string& op_type);
   PDNode* assert_is_only_output_of_op(const std::string& op_type);
   PDNode* assert_op_has_n_inputs(const std::string& op_type, size_t n);
@@ -138,13 +140,15 @@ struct PDNode {
                               const std::string& argument);
   PDNode* assert_is_ops_nth_input(
       const std::unordered_set<std::string>& op_types,
-      const std::string& argument, int nth);
+      const std::string& argument,
+      int nth);
   PDNode* assert_is_ops_input(const std::unordered_set<std::string>& op_types);
   PDNode* assert_is_ops_input(const std::unordered_set<std::string>& op_types,
                               const std::string& argument);
   PDNode* assert_is_ops_nth_output(
       const std::unordered_set<std::string>& op_types,
-      const std::string& argument, int nth);
+      const std::string& argument,
+      int nth);
   PDNode* assert_is_only_input_of_ops(
       const std::unordered_set<std::string>& op_types);
@@ -164,10 +168,13 @@ struct PDNode {
   }
  private:
-  PDNode(PDPattern* pattern, const std::string& name = "",
+  PDNode(PDPattern* pattern,
+         const std::string& name = "",
          Type type = Type::kVar)
       : pattern_(pattern), name_(name), type_(type) {}
-  PDNode(teller_t&& teller, PDPattern* pattern, const std::string& name = "",
+  PDNode(teller_t&& teller,
+         PDPattern* pattern,
+         const std::string& name = "",
          Type type = Type::kVar)
       : teller_(std::move(teller)),
         pattern_(pattern),
@@ -398,16 +405,25 @@ struct KeyCounter {
     return x;
   }
+#ifdef PADDLE_WITH_TENSORRT
+  static int IncCounter(const std::string& key) { return dic_[key]++; }
+  static void CleanCounter() { dic_.clear(); }
+
+ private:
+  static thread_local std::unordered_map<std::string, size_t> dic_;
+#else
   int IncCounter(const std::string& key) { return dic_[key]++; }
  private:
   std::unordered_map<std::string, size_t> dic_;
+#endif
 };
 // Generate a unique PDNode's name with name_scope and id.
 // The format is {name_scope}/{repr}/{id}/{name}
 static std::string PDNodeName(const std::string& name_scope,
-                              const std::string& repr, size_t id,
+                              const std::string& repr,
+                              size_t id,
                               const std::string& name) {
   return string::Sprintf("%s/%s/%d/%s", name_scope, repr, id, name);
 }
@@ -415,15 +431,15 @@ static std::string PDNodeName(const std::string& name_scope,
 // The format is {name_scope}/{repr}/{id}
 static std::string PDNodeName(const std::string& name_scope,
                               const std::string& repr) {
-  return string::Sprintf("%s/%s/%d", name_scope, repr,
-                         KeyCounter::Instance().IncCounter(repr));
+  return string::Sprintf(
+      "%s/%s/%d", name_scope, repr, KeyCounter::Instance().IncCounter(repr));
 }
 // Generate a unique key. It can be used for a universally unique temporary
 // name.
 // The format is {repr}/{id}
 static std::string UniqueKey(const std::string& repr) {
-  return string::Sprintf("%s/%d", repr,
-                         KeyCounter::Instance().IncCounter(repr));
+  return string::Sprintf(
+      "%s/%d", repr, KeyCounter::Instance().IncCounter(repr));
 }
@@ -440,17 +456,19 @@ static std::string UniqueKey(const std::string& repr) {
 // Declare a PDNode in a pattern, will create two methods:
 // arg: the argument declared by PATTERN_DECL_NODE in a pattern definition.
 // pat: the pattern object.
 #define GET_IR_NODE_FROM_SUBGRAPH(var, arg, pat)                               \
-  PADDLE_ENFORCE_NE(subgraph.count(pat.arg##_n()), 0UL,                        \
+  PADDLE_ENFORCE_NE(subgraph.count(pat.arg##_n()),                             \
+                    0UL,                                                       \
                     platform::errors::NotFound("Node not found for PDNode %s", \
                                                pat.arg##_repr()));             \
   Node* var = subgraph.at(pat.arg##_n());                                      \
-  PADDLE_ENFORCE_NOT_NULL(                                                     \
-      var, platform::errors::NotFound("node %s not exists in the sub-graph",   \
-                                      #arg));
+  PADDLE_ENFORCE_NOT_NULL(var,                                                 \
+                          platform::errors::NotFound(                          \
+                              "node %s not exists in the sub-graph", #arg));
 // The base class of all the patterns.
 struct PatternBase {
-  PatternBase(PDPattern* pattern, const std::string& name_scope,
+  PatternBase(PDPattern* pattern,
+              const std::string& name_scope,
               const std::string& repr)
       : pattern(pattern),
         name_scope_(name_scope),
@@ -476,7 +494,8 @@ struct ConvBN : public PatternBase {
   ConvBN(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "conv_bn") {}
-  PDNode* operator()(PDNode* conv_input, const std::string& conv_type,
+  PDNode* operator()(PDNode* conv_input,
+                     const std::string& conv_type,
                      bool with_eltwise_add);
   // declare operator node's name
@@ -514,7 +533,8 @@ struct ConvActivation : public PatternBase {
   ConvActivation(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "conv_activation") {}
-  PDNode* operator()(PDNode* conv_input, std::string conv_type = "conv2d",
+  PDNode* operator()(PDNode* conv_input,
+                     std::string conv_type = "conv2d",
                      std::string activation_type = "relu");
   // declare operator node's name
@@ -536,7 +556,8 @@ struct ElementwiseActivation : public PatternBase {
   ElementwiseActivation(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "elementwise_add_activation") {}
-  PDNode* operator()(PDNode* elementwise_a, const std::string& elementwise_type,
+  PDNode* operator()(PDNode* elementwise_a,
+                     const std::string& elementwise_type,
                      const std::string& activation_type);
   // declare operator node's name
@@ -936,7 +957,8 @@ struct LinearAct : public PatternBase {
   PDNode* operator()(PDNode* x,
                      const std::unordered_set<std::string>& act_types,
-                     bool with_grad_link, bool is_act_grad_x_from_act);
+                     bool with_grad_link,
+                     bool is_act_grad_x_from_act);
   // declare operator node's name
   PATTERN_DECL_NODE(matmul);
@@ -965,7 +987,8 @@ struct ElewiseAddMatmulAct : public PatternBase {
   PDNode* operator()(PDNode* x,
                      const std::unordered_set<std::string>& act_grad_types,
-                     bool without_x_gradient, bool is_act_grad_x_from_act);
+                     bool without_x_gradient,
+                     bool is_act_grad_x_from_act);
   // declare operator node's name
   PATTERN_DECL_NODE(ele_add_grad);
@@ -1062,7 +1085,8 @@ struct Elementwise : public PatternBase {
   Elementwise(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "elementwise") {}
-  PDNode* operator()(PDNode* x_var, PDNode* y_var,
+  PDNode* operator()(PDNode* x_var,
+                     PDNode* y_var,
                      const std::string elementwise_type);
   PATTERN_DECL_NODE(elementwise_op);
@@ -1088,11 +1112,14 @@ struct ElementwiseOp : public PatternBase {
 // This pattern allows operator output to be X or Y
 // and residual data Y or X, based on as_x flag
 struct ResidualElementwise : public PatternBase {
-  ResidualElementwise(PDPattern* pattern, const std::string& name_scope,
+  ResidualElementwise(PDPattern* pattern,
+                      const std::string& name_scope,
                       bool as_x)
       : PatternBase(pattern, name_scope, "residual_elementwise") {}
-  PDNode* operator()(PDNode* op_var, PDNode* residual_var,
-                     const std::string elementwise_type, bool as_x);
+  PDNode* operator()(PDNode* op_var,
+                     PDNode* residual_var,
+                     const std::string elementwise_type,
+                     bool as_x);
   PATTERN_DECL_NODE(operator_output);
   PATTERN_DECL_NODE(residual_data);
@@ -1467,8 +1494,8 @@ struct ConvElementwiseaddAct : public PatternBase {
 // Conv + ElementwiseAdd + ElementwiseAdd + Activation
 struct ConvElementwiseadd2Act : public PatternBase {
   ConvElementwiseadd2Act(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope,
-                    "conv_elementwiseadd2_elementwiseadd_act") {}
+      : PatternBase(
+            pattern, name_scope, "conv_elementwiseadd2_elementwiseadd_act") {}
   PDNode* operator()(PDNode* conv_in);
@@ -1702,7 +1729,8 @@ struct DequantOpFuse : public PatternBase {
   DequantOpFuse(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "dequant_fuse") {}
-  void operator()(PDNode* quant_op_input, const std::string& quantized_op_type,
+  void operator()(PDNode* quant_op_input,
+                  const std::string& quantized_op_type,
                   const std::string& dequant_type,
                   const std::string& weight_name);
@@ -1758,8 +1786,8 @@ struct DeleteQuantDequantOpPattern : public PatternBase {
 struct DeleteQuantDequantFilterOpPattern : public PatternBase {
   DeleteQuantDequantFilterOpPattern(PDPattern* pattern,
                                     const std::string& name_scope)
-      : PatternBase(pattern, name_scope,
-                    "delete_quantdequant_filter_op_pattern") {}
+      : PatternBase(
+            pattern, name_scope, "delete_quantdequant_filter_op_pattern") {}
   void operator()();
@@ -1773,7 +1801,8 @@ struct DeleteQuantDequantFilterOpPattern : public PatternBase {
 struct DeleteWeightQuantDequantLinearOpPattern : public PatternBase {
   DeleteWeightQuantDequantLinearOpPattern(PDPattern* pattern,
                                           const std::string& name_scope)
-      : PatternBase(pattern, name_scope,
+      : PatternBase(pattern,
+                    name_scope,
                     "delete_weight_quant_dequant_linear_op_pattern") {}
   void operator()();
@@ -1788,8 +1817,8 @@ struct DeleteWeightQuantDequantLinearOpPattern : public PatternBase {
 struct DeleteQuantDequantLinearOpPattern : public PatternBase {
   DeleteQuantDequantLinearOpPattern(PDPattern* pattern,
                                     const std::string& name_scope)
-      : PatternBase(pattern, name_scope,
-                    "delete_quant_dequant_linear_op_pattern") {}
+      : PatternBase(
+            pattern, name_scope, "delete_quant_dequant_linear_op_pattern") {}
   void operator()();
@@ -1814,7 +1843,8 @@ struct ReshapeTransposeMatmulPattern : public PatternBase {
                                 const std::string& name_scope)
       : PatternBase(pattern, name_scope, "reshape_transpose_matmul") {}
-  PDNode* operator()(const std::string& op_name, bool with_reshape_xshape,
+  PDNode* operator()(const std::string& op_name,
+                     bool with_reshape_xshape,
                      bool with_transpose_xshape);
   PATTERN_DECL_NODE(reshape_in);
......
paddle/fluid/inference/api/analysis_predictor.cc

@@ -82,9 +82,9 @@ namespace paddle {
 using inference::Singleton;
 #if PADDLE_WITH_TENSORRT
-using inference::tensorrt::TRTInt8Calibrator;
 using inference::tensorrt::TRTCalibratorEngine;
 using inference::tensorrt::TRTCalibratorEngineManager;
+using inference::tensorrt::TRTInt8Calibrator;
 #endif
 int AnalysisPredictor::clone_num_ = 1;
@@ -101,7 +101,8 @@ bool IsPersistable(const framework::VarDesc *var) {
 }
 }  // namespace
-bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
+bool PaddleTensorToLoDTensor(const PaddleTensor &pt,
+                             framework::LoDTensor *t,
                              const platform::Place &place) {
   framework::DDim ddim = phi::make_ddim(pt.shape);
   void *input_ptr;
@@ -129,18 +130,19 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
   if (platform::is_cpu_place(place)) {
     // TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy.
-    std::memcpy(static_cast<void *>(input_ptr), pt.data.data(),
-                pt.data.length());
+    std::memcpy(
+        static_cast<void *>(input_ptr), pt.data.data(), pt.data.length());
   } else if (platform::is_ipu_place(place)) {
 #ifdef PADDLE_WITH_IPU
-    std::memcpy(static_cast<void *>(input_ptr), pt.data.data(),
-                pt.data.length());
+    std::memcpy(
+        static_cast<void *>(input_ptr), pt.data.data(), pt.data.length());
 #else
     PADDLE_THROW(paddle::platform::errors::Fatal(
         "Not compile with WITH_IPU, should not reach here."));
 #endif
   } else if (platform::is_gpu_place(place)) {
-    PADDLE_ENFORCE_EQ(platform::is_xpu_place(place), false,
+    PADDLE_ENFORCE_EQ(platform::is_xpu_place(place),
+                      false,
                       platform::errors::InvalidArgument(
                           "Only one choice can be made between CPU and XPU."));
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -148,8 +150,11 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
     auto *dev_ctx =
         static_cast<const platform::CUDADeviceContext *>(pool.Get(place));
     auto dst_gpu_place = place;
-    memory::Copy(dst_gpu_place, static_cast<void *>(input_ptr),
-                 platform::CPUPlace(), pt.data.data(), pt.data.length(),
+    memory::Copy(dst_gpu_place,
+                 static_cast<void *>(input_ptr),
+                 platform::CPUPlace(),
+                 pt.data.data(),
+                 pt.data.length(),
                  dev_ctx->stream());
 #else
     PADDLE_THROW(paddle::platform::errors::Fatal(
@@ -158,8 +163,11 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
   } else if (platform::is_xpu_place(place)) {
 #ifdef PADDLE_WITH_XPU
     auto dst_xpu_place = place;
-    memory::Copy(dst_xpu_place, static_cast<void *>(input_ptr),
-                 platform::CPUPlace(), pt.data.data(), pt.data.length());
+    memory::Copy(dst_xpu_place,
+                 static_cast<void *>(input_ptr),
+                 platform::CPUPlace(),
+                 pt.data.data(),
+                 pt.data.length());
 #else
     PADDLE_THROW(paddle::platform::errors::Fatal(
         "Not compile with XPU, should not reach here."));
@@ -263,7 +271,8 @@ bool AnalysisPredictor::PrepareProgram(
 }
 bool AnalysisPredictor::CreateExecutor() {
   if (config_.use_gpu()) {
-    PADDLE_ENFORCE_EQ(config_.use_xpu(), false,
+    PADDLE_ENFORCE_EQ(config_.use_xpu(),
+                      false,
                       platform::errors::InvalidArgument(
                           "Only one choice can be made between CPU and XPU."));
     place_ = paddle::platform::CUDAPlace(config_.gpu_device_id());
@@ -357,7 +366,8 @@ static bool IsPrepareDataOptTargetOp(framework::OpDesc *op) {
 }
 static void DisablePrepareDataOpt(
-    std::shared_ptr<framework::ProgramDesc> inference_program, int block,
+    std::shared_ptr<framework::ProgramDesc> inference_program,
+    int block,
     bool pre_disable_opt) {
   bool disable_opt = false;
   auto &infer_block = inference_program->Block(block);
@@ -367,8 +377,8 @@ static void DisablePrepareDataOpt(
     }
     if (op->HasAttr("sub_block")) {
       int blockID = op->GetBlockAttrId("sub_block");
-      DisablePrepareDataOpt(inference_program, blockID,
-                            disable_opt || pre_disable_opt);
+      DisablePrepareDataOpt(
+          inference_program, blockID, disable_opt || pre_disable_opt);
     }
     // disable prepare data if unfriendly op is found
     if (!disable_opt) {
@@ -386,8 +396,8 @@ bool AnalysisPredictor::PrepareExecutor() {
 #endif
   DisablePrepareDataOpt(inference_program_, 0, false);
-  executor_->Prepare(sub_scope_, *inference_program_, 0,
-                     config_.use_feed_fetch_ops_);
+  executor_->Prepare(
+      sub_scope_, *inference_program_, 0, config_.use_feed_fetch_ops_);
   PADDLE_ENFORCE_NOT_NULL(sub_scope_,
                           platform::errors::PreconditionNotMet(
@@ -433,8 +443,13 @@ bool AnalysisPredictor::PrepareFleetExecutor() {
     feed_fetch_vars.emplace_back(pair.second);
   }
   fleet_exe_->Init(config_.dist_config().carrier_id(),
-                   *(inference_program_.get()), scope_.get(), place_, 1,
-                   {task_node_.get()}, id_to_rank, feed_fetch_vars);
+                   *(inference_program_.get()),
+                   scope_.get(),
+                   place_,
+                   1,
+                   {task_node_.get()},
+                   id_to_rank,
+                   feed_fetch_vars);
   return true;
 }
@@ -471,8 +486,12 @@ bool AnalysisPredictor::CommInit() {
       peer_endpoints.emplace_back(
          config_.dist_config().trainer_endpoints()[rank]);
     }
-    InsertCommOp(var_name_base + std::to_string(order), ranks_in_group,
-                 rank_in_group, peer_endpoints, comm_init_block, ring_id);
+    InsertCommOp(var_name_base + std::to_string(order),
+                 ranks_in_group,
+                 rank_in_group,
+                 peer_endpoints,
+                 comm_init_block,
+                 ring_id);
     order += 1;
   }
   framework::NaiveExecutor e(place_);
@@ -484,8 +503,11 @@ bool AnalysisPredictor::CommInit() {
 }
 void AnalysisPredictor::InsertCommOp(
-    std::string tmp_var_name, int nranks, int rank,
-    const std::vector<std::string> &peer_endpoints, framework::BlockDesc *block,
+    std::string tmp_var_name,
+    int nranks,
+    int rank,
+    const std::vector<std::string> &peer_endpoints,
+    framework::BlockDesc *block,
     int ring_id) {
   /*
    * tmp_var_name: the var name for var comm_id
@@ -542,7 +564,8 @@ bool AnalysisPredictor::LoadConverterConfig(
           << config_.dist_config().comm_init_config() << "\n";
   std::ifstream fin(config_.dist_config().comm_init_config(), std::ios::in);
   PADDLE_ENFORCE_EQ(
-      static_cast<bool>(fin.is_open()), true,
+      static_cast<bool>(fin.is_open()),
+      true,
       platform::errors::NotFound(
          "Cannot open file %s, please confirm whether the file is normal.",
          config_.dist_config().comm_init_config()));
@@ -686,8 +709,9 @@ bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs,
   timer.tic();
   // set feed variable
   framework::Scope *scope = sub_scope_ ? sub_scope_ : scope_.get();
-  PADDLE_ENFORCE_NOT_NULL(scope, platform::errors::PreconditionNotMet(
-                                     "The scope should not be nullptr."));
+  PADDLE_ENFORCE_NOT_NULL(
+      scope,
+      platform::errors::PreconditionNotMet("The scope should not be nullptr."));
   if (!SetFeed(inputs, scope)) {
     LOG(ERROR) << "fail to set feed";
     return false;
@@ -790,9 +814,11 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
   for (size_t i = 0; i < fetches_.size(); ++i) {
     int idx = BOOST_GET_CONST(int, fetches_[i]->GetAttr("col"));
     PADDLE_ENFORCE_EQ(
-        static_cast<size_t>(idx), i,
+        static_cast<size_t>(idx),
+        i,
         platform::errors::InvalidArgument(
-            "Fetch op's col attr(%d) should be equal to the index(%d)", idx,
+            "Fetch op's col attr(%d) should be equal to the index(%d)",
+            idx,
             i));
     framework::FetchType &fetch_var =
         framework::GetFetchVariable(*scope, "fetch", idx);
@@ -833,7 +859,8 @@ void AnalysisPredictor::PrepareArgument() {
   if (!config_.model_dir().empty()) {
     argument_.SetModelDir(config_.model_dir());
   } else {
-    PADDLE_ENFORCE_EQ(config_.prog_file().empty(), false,
+    PADDLE_ENFORCE_EQ(config_.prog_file().empty(),
+                      false,
                       platform::errors::PreconditionNotMet(
                           "Either model_dir or prog_file should be set."));
     std::string dir = inference::analysis::GetDirRoot(config_.prog_file());
@@ -969,7 +996,8 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
   Analyzer().Run(&argument_);
   PADDLE_ENFORCE_EQ(
-      argument_.scope_valid(), true,
+      argument_.scope_valid(),
+      true,
       platform::errors::InvalidArgument("The argument scope should be valid."));
   VLOG(5) << "to prepare executor";
   ARGUMENT_CHECK_FIELD((&argument_), ir_analyzed_program);
@@ -1008,8 +1036,9 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
 }
 template <>
-std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
-    AnalysisConfig, PaddleEngineKind::kAnalysis>(const AnalysisConfig &config) {
+std::unique_ptr<PaddlePredictor>
+CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
+    const AnalysisConfig &config) {
   // TODO(NHZlX): Should add the link to the doc of
   // paddle_infer::CreatePredictor<paddle_infer::Config>
   if (config.glog_info_disabled()) {
@@ -1018,7 +1047,8 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
   }
   VLOG(3) << "create AnalysisConfig";
   PADDLE_ENFORCE_EQ(
-      config.is_valid(), true,
+      config.is_valid(),
+      true,
       platform::errors::InvalidArgument(
          "Note: Each config can only be used for one predictor."));
@@ -1035,11 +1065,13 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
   std::call_once(gflags_initialized, [&]() {
     std::vector<std::string> gflags;
     PADDLE_ENFORCE_GE(
-        config.memory_pool_init_size_mb(), 0.f,
+        config.memory_pool_init_size_mb(),
+        0.f,
         platform::errors::InvalidArgument(
            "The size of memory pool should be greater than 0."));
     PADDLE_ENFORCE_GE(
-        config.gpu_device_id(), 0,
+        config.gpu_device_id(),
+        0,
        platform::errors::InvalidArgument(
            "Invalid device id (%d). The device id should be greater than 0.",
            config.gpu_device_id()));
@@ -1105,6 +1137,10 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
   config.SetInValid();
   auto predictor_p = dynamic_cast<AnalysisPredictor *>(predictor.get());
+#ifdef PADDLE_WITH_TENSORRT
+  paddle::framework::ir::patterns::KeyCounter::Instance().CleanCounter();
+#endif
+
   if (!predictor_p->Init(nullptr)) {
     return nullptr;
   }
@@ -1154,8 +1190,9 @@ void AnalysisPredictor::PrepareFeedFetch() {
 }
 void AnalysisPredictor::CreateFeedFetchVar(framework::Scope *scope) {
-  PADDLE_ENFORCE_NOT_NULL(scope, platform::errors::InvalidArgument(
-                                     "The scope should not be nullptr."));
+  PADDLE_ENFORCE_NOT_NULL(
+      scope,
+      platform::errors::InvalidArgument("The scope should not be nullptr."));
   auto *var = scope->Var("feed");
   var->GetMutable<framework::FeedList>();
   var = scope->Var("fetch");
@@ -1176,8 +1213,9 @@ AnalysisPredictor::GetInputTensorShape() {
   std::vector<std::string> names = GetInputNames();
   for (std::string name : names) {
     auto *var = inference_program_->Block(0).FindVar(name);
-    PADDLE_ENFORCE_NOT_NULL(var, platform::errors::PreconditionNotMet(
-                                     "Input %s does not exist.", name));
+    PADDLE_ENFORCE_NOT_NULL(
+        var,
+        platform::errors::PreconditionNotMet("Input %s does not exist.", name));
     input_shapes[name] = var->GetShape();
   }
   return input_shapes;
@@ -1398,7 +1436,8 @@ void AnalysisPredictor::StatisticShapeRangeInfo() {
     std::vector<std::pair<int32_t, int32_t>> counter;
     for (auto &it : m) counter.push_back(it);
     std::sort(
-        counter.begin(), counter.end(),
+        counter.begin(),
+        counter.end(),
         [](std::pair<int32_t, int32_t> &a, std::pair<int32_t, int32_t> &b) {
           return a.second > b.second;
         });
@@ -1420,8 +1459,8 @@ void AnalysisPredictor::StatisticShapeRangeInfo() {
     opt_shapes[name] = opt_shape;
   }
-  inference::SerializeShapeRangeInfo(config_.shape_range_info_path(),
-                                     min_shapes, max_shapes, opt_shapes);
+  inference::SerializeShapeRangeInfo(
+      config_.shape_range_info_path(), min_shapes, max_shapes, opt_shapes);
 }
@@ -1441,7 +1480,8 @@ bool AnalysisPredictor::LoadProgramDesc() {
     return false;
   }
   LOG(ERROR) << string::Sprintf(
-      "not valid model path '%s' or program path '%s'.", config_.model_dir(),
+      "not valid model path '%s' or program path '%s'.",
+      config_.model_dir(),
       config_.params_file());
   return false;
 }
@@ -1453,7 +1493,8 @@ bool AnalysisPredictor::LoadProgramDesc() {
   // Read binary
   std::ifstream fin(filename, std::ios::in | std::ios::binary);
   PADDLE_ENFORCE_EQ(
-      static_cast<bool>(fin.is_open()), true,
+      static_cast<bool>(fin.is_open()),
+      true,
      platform::errors::NotFound(
          "Cannot open file %s, please confirm whether the file is normal.",
          filename));
@@ -1555,7 +1596,8 @@ void AnalysisPredictor::ClearIntermediateTensor() {
 #if PADDLE_WITH_TENSORRT
 bool AnalysisPredictor::SaveTrtCalibToDisk() {
-  PADDLE_ENFORCE_EQ(config_.tensorrt_engine_enabled(), true,
+  PADDLE_ENFORCE_EQ(config_.tensorrt_engine_enabled(),
+                    true,
                     platform::errors::PreconditionNotMet(
                         "This func can be invoked only in trt mode"));
   auto &block = inference_program_->Block(0);
@@ -1782,8 +1824,10 @@ Predictor::Predictor(const Config &config) {
         << "Paddle2ONNX do't support convert the Model, fall back to using "
            "Paddle Inference.";
     } else {
-      predictor_ = paddle::CreatePaddlePredictor<
-          Config, paddle::PaddleEngineKind::kONNXRuntime>(config);
+      predictor_ =
+          paddle::CreatePaddlePredictor<Config,
+                                        paddle::PaddleEngineKind::kONNXRuntime>(
+              config);
       return;
     }
 #else
@@ -1793,8 +1837,10 @@ Predictor::Predictor(const Config &config) {
        "fall back to using Paddle Inference.";
 #endif
   }
-  predictor_ = paddle::CreatePaddlePredictor<
-      Config, paddle::PaddleEngineKind::kAnalysis>(config);
+  predictor_ =
+      paddle::CreatePaddlePredictor<Config,
+                                    paddle::PaddleEngineKind::kAnalysis>(
+          config);
 }
@@ -1876,7 +1922,8 @@ std::shared_ptr<Predictor> CreatePredictor(const Config &config) {  // NOLINT
 namespace services {
 PredictorPool::PredictorPool(const Config &config, size_t size) {
   PADDLE_ENFORCE_GE(
-      size, 1UL,
+      size,
+      1UL,
       paddle::platform::errors::InvalidArgument(
           "The predictor pool size should be greater than 1, but it's (%d)",
           size));
@@ -1895,9 +1942,11 @@ PredictorPool::PredictorPool(const Config &config, size_t size) {
 Predictor *PredictorPool::Retrive(size_t idx) {
   PADDLE_ENFORCE_LT(
-      idx, preds_.size() + 1,
+      idx,
+      preds_.size() + 1,
       paddle::platform::errors::InvalidArgument(
          "There are (%d) predictors in the pool, but the idx is (%d)",
          idx,
          preds_.size() + 1));
   if (idx == 0) {
     return main_pred_.get();
......
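For completeness, a self-contained illustration (not part of the commit) of the behavior the patch relies on: a `thread_local` map gives every thread an independent counter sequence, while a single thread's counters keep advancing until cleared.

    #include <cstdio>
    #include <string>
    #include <thread>
    #include <unordered_map>

    // Per-thread counter map, mirroring the KeyCounter change above.
    thread_local std::unordered_map<std::string, size_t> dic;

    int IncCounter(const std::string& key) { return dic[key]++; }

    void Worker() {
      for (int i = 0; i < 3; ++i) std::printf("%d ", IncCounter("conv_bn"));
      std::printf("\n");
    }

    int main() {
      Worker();                    // main thread prints: 0 1 2
      std::thread(Worker).join();  // a fresh thread also prints: 0 1 2
      Worker();                    // main thread again: 3 4 5 (its map persisted)
      return 0;
    }

Before this patch the map was a plain member of the `KeyCounter` singleton, so concurrent predictor creation could race on it and interleave the `{name_scope}/{repr}/{id}` sequences; `thread_local` storage plus the `CleanCounter()` call in `CreatePaddlePredictor` is presumably what keeps the generated pattern-node names deterministic per thread.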