Unverified commit 1ea9971a authored by JingZhuangzhuang, committed by GitHub

modify graph_pattern to thread_local (#43945)

Parent 26187c27
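
The functional change in this commit is small: under PADDLE_WITH_TENSORRT, the KeyCounter map `dic_` that numbers pattern nodes (used by PDNodeName and UniqueKey) becomes a thread_local static data member, and CreatePaddlePredictor clears the calling thread's counters via CleanCounter() before initializing a new predictor; most of the remaining hunks are mechanical line re-wrapping. The following is a minimal, self-contained sketch of the same idiom, for orientation only — the names Counter, NextId, and Clean are invented for this example and are not Paddle APIs.

#include <cassert>
#include <string>
#include <thread>
#include <unordered_map>

// A per-thread key counter, assuming the same idiom as the patched KeyCounter:
// a thread_local static data member declared in the class and defined exactly
// once at namespace scope (thread_local must be repeated on the definition).
class Counter {
 public:
  // Returns 0, 1, 2, ... for each key, independently in every thread.
  static size_t NextId(const std::string& key) { return dic_[key]++; }
  // Clears only the counters of the calling thread.
  static void Clean() { dic_.clear(); }

 private:
  static thread_local std::unordered_map<std::string, size_t> dic_;
};

thread_local std::unordered_map<std::string, size_t> Counter::dic_;

int main() {
  assert(Counter::NextId("conv_bn") == 0);
  assert(Counter::NextId("conv_bn") == 1);
  // A new thread starts from its own empty map, so concurrently created
  // predictors no longer interfere with each other's node numbering.
  std::thread([] { assert(Counter::NextId("conv_bn") == 0); }).join();
  Counter::Clean();  // analogous to KeyCounter::Instance().CleanCounter()
  assert(Counter::NextId("conv_bn") == 0);
  return 0;
}
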
@@ -28,10 +28,17 @@ using string::Style;
size_t PDPattern::id_ = 0UL;
#ifdef PADDLE_WITH_TENSORRT
namespace patterns {
thread_local std::unordered_map<std::string, size_t> KeyCounter::dic_;
}
#endif
PDNode *PDPattern::NewNode(const std::string &name) {
if (!name.empty()) {
PADDLE_ENFORCE_EQ(
node_map_.count(name), 0UL,
node_map_.count(name),
0UL,
platform::errors::PreconditionNotMet(
"PDNode's name should be unique, get duplicate [%s]", name));
}
@@ -45,7 +52,8 @@ PDNode *PDPattern::NewNode(const std::string &name) {
PDNode *PDPattern::NewNode(PDNode::teller_t &&teller, const std::string &name) {
if (!name.empty()) {
PADDLE_ENFORCE_EQ(
node_map_.count(name), 0UL,
node_map_.count(name),
0UL,
platform::errors::PreconditionNotMet(
"PDNode's name should be unique, get duplicate [%s]", name));
}
@@ -70,7 +78,9 @@ void PDPattern::AddEdge(PDNode *a, PDNode *b) {
a, platform::errors::NotFound("PDNode %s is not found.", a->name()));
PADDLE_ENFORCE_NOT_NULL(
b, platform::errors::NotFound("PDNode %s is not found.", b->name()));
PADDLE_ENFORCE_NE(a, b, platform::errors::PermissionDenied(
PADDLE_ENFORCE_NE(a,
b,
platform::errors::PermissionDenied(
"Cannot connect the same node in the graph."));
edges_.emplace_back(a, b);
}
@@ -128,7 +138,8 @@ void GraphPatternDetector::ValidateByNodeRole(
subgraphs->erase(
std::remove_if(
subgraphs->begin(), subgraphs->end(),
subgraphs->begin(),
subgraphs->end(),
[](const GraphPatternDetector::subgraph_t &subgraph) -> bool {
// Collect the inputs and outputs.
std::set<Node *> ios;
@@ -310,7 +321,8 @@ void GraphPatternDetector::SortSubgraphs(
}
std::sort(
subgraphs->begin(), subgraphs->end(),
subgraphs->begin(),
subgraphs->end(),
[](const GraphPatternDetector::subgraph_t &a,
const GraphPatternDetector::subgraph_t &b) {
for (auto &item : a) {
@@ -438,7 +450,8 @@ PDNode *PDNode::assert_is_persistable_var() {
}
PDNode *PDNode::assert_is_op_nth_input(const std::string &op_type,
const std::string &argument, int nth) {
const std::string &argument,
int nth) {
assert_is_var();
assert_is_op_input(op_type);
asserts_.emplace_back([=](Node *x) {
@@ -453,7 +466,8 @@ PDNode *PDNode::assert_is_op_nth_input(const std::string &op_type,
}
PDNode *PDNode::assert_is_op_nth_output(const std::string &op_type,
const std::string &argument, int nth) {
const std::string &argument,
int nth) {
assert_is_var();
asserts_.emplace_back([=](Node *x) {
for (auto *op : x->inputs) {
@@ -580,7 +594,8 @@ PDNode *PDNode::assert_is_ops(const std::unordered_set<std::string> &op_types) {
PDNode *PDNode::assert_is_ops_nth_input(
const std::unordered_set<std::string> &op_types,
const std::string &argument, int nth) {
const std::string &argument,
int nth) {
assert_is_var();
assert_is_ops_input(op_types);
asserts_.emplace_back([=](Node *x) {
@@ -596,7 +611,8 @@ PDNode *PDNode::assert_is_ops_nth_input(
PDNode *PDNode::assert_is_ops_nth_output(
const std::unordered_set<std::string> &op_types,
const std::string &argument, int nth) {
const std::string &argument,
int nth) {
assert_is_var();
asserts_.emplace_back([=](Node *x) {
for (auto *op : x->inputs) {
@@ -693,11 +709,13 @@ bool VarLinksToOp(Node *node, const std::string &op_type) {
bool IsNthInput(Node *var, Node *op, const std::string &argument, size_t nth) {
PADDLE_ENFORCE_EQ(
var->IsVar(), true,
var->IsVar(),
true,
platform::errors::InvalidArgument(
"First parameter of function IsNthInput must be Node::Var"));
PADDLE_ENFORCE_EQ(
op->IsOp(), true,
op->IsOp(),
true,
platform::errors::InvalidArgument(
"Second parameter of function IsNthInput must be Node::Op"));
if (!HasInput(op, argument) || op->Op()->Input(argument).size() <= nth)
@@ -707,7 +725,8 @@ bool IsNthInput(Node *var, Node *op, const std::string &argument, size_t nth) {
bool HasInput(Node *op, const std::string &argument) {
PADDLE_ENFORCE_EQ(
op->IsOp(), true,
op->IsOp(),
true,
platform::errors::InvalidArgument(
"First parameter of function HasInput must be Node::Op"));
auto const &names = op->Op()->InputNames();
@@ -718,7 +737,8 @@ bool HasInput(Node *op, const std::string &argument) {
bool HasOutput(Node *op, const std::string &argument) {
PADDLE_ENFORCE_EQ(
op->IsOp(), true,
op->IsOp(),
true,
platform::errors::InvalidArgument(
"First parameter of function HasOuput must be Node::Op"));
auto const &names = op->Op()->OutputNames();
@@ -729,11 +749,13 @@ bool HasOutput(Node *op, const std::string &argument) {
bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) {
PADDLE_ENFORCE_EQ(
var->IsVar(), true,
var->IsVar(),
true,
platform::errors::InvalidArgument(
"First parameter of function IsNthOutput must be Node::Var"));
PADDLE_ENFORCE_EQ(
op->IsOp(), true,
op->IsOp(),
true,
platform::errors::InvalidArgument(
"Second parameter of function IsNthOutput must be Node::Op"));
if (!HasOutput(op, argument) || op->Op()->Output(argument).size() <= nth)
@@ -875,22 +897,35 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
eltwise_op->LinksFrom({conv_out_var, eltwise_y_in_var})
.LinksTo({eltwise_out_var});
batch_norm_op
->LinksFrom({eltwise_out_var, bn_scale_var, bn_bias_var, bn_mean_var,
->LinksFrom({eltwise_out_var,
bn_scale_var,
bn_bias_var,
bn_mean_var,
bn_variance_var})
.LinksTo({bn_out_var, bn_mean_out_var, bn_variance_out_var,
bn_saved_mean_var, bn_saved_variance_var});
.LinksTo({bn_out_var,
bn_mean_out_var,
bn_variance_out_var,
bn_saved_mean_var,
bn_saved_variance_var});
} else {
batch_norm_op
->LinksFrom({conv_out_var, bn_scale_var, bn_bias_var, bn_mean_var,
->LinksFrom({conv_out_var,
bn_scale_var,
bn_bias_var,
bn_mean_var,
bn_variance_var})
.LinksTo({bn_out_var, bn_mean_out_var, bn_variance_out_var,
bn_saved_mean_var, bn_saved_variance_var});
.LinksTo({bn_out_var,
bn_mean_out_var,
bn_variance_out_var,
bn_saved_mean_var,
bn_saved_variance_var});
}
return bn_out_var;
}
PDNode *patterns::ConvActivation::operator()(
paddle::framework::ir::PDNode *conv_input, std::string conv_type,
paddle::framework::ir::PDNode *conv_input,
std::string conv_type,
std::string activation_type) {
// Create Operators
conv_input->assert_is_op_input(conv_type, "Input");
@@ -920,7 +955,8 @@ PDNode *patterns::ConvActivation::operator()(
PDNode *patterns::ElementwiseActivation::operator()(
paddle::framework::ir::PDNode *elementwise_a,
const std::string &elementwise_type, const std::string &activation_type) {
const std::string &elementwise_type,
const std::string &activation_type) {
// Create Operators
elementwise_a->assert_is_op_input(elementwise_type, "X");
auto *elementwise_op =
@@ -995,7 +1031,8 @@ PDNode *patterns::SeqConvEltAddRelu::operator()(
}
PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
bool with_bias, bool with_relu) {
bool with_bias,
bool with_relu) {
// Create shared nodes.
x->assert_is_op_input("mul", "X");
auto *mul = pattern->NewNode(mul_repr())->assert_is_op("mul");
@@ -1261,8 +1298,12 @@ PDNode *patterns::BatchNormAct::operator()(
bn->LinksFrom(
{bn_x_var, bn_scale_var, bn_bias_var, bn_variance_var, bn_mean_var})
.LinksTo({bn_mean_out_var, bn_variance_out_var, bn_saved_variance_var,
bn_saved_mean_var, bn_reserve_space, bn_out_var});
.LinksTo({bn_mean_out_var,
bn_variance_out_var,
bn_saved_variance_var,
bn_saved_mean_var,
bn_reserve_space,
bn_out_var});
act->LinksFrom({bn_out_var}).LinksTo({act_out_var});
return act_out_var;
@@ -1319,8 +1360,13 @@ PDNode *patterns::BatchNormActGrad::operator()(
.LinksTo({d_intermediate_var});
bn_grad
->LinksFrom({bn_x_var, d_intermediate_var, bn_scale_var, bn_bias_var,
bn_saved_mean_var, bn_saved_variance_var, bn_reserve_space})
->LinksFrom({bn_x_var,
d_intermediate_var,
bn_scale_var,
bn_bias_var,
bn_saved_mean_var,
bn_saved_variance_var,
bn_reserve_space})
.LinksTo({d_bn_x_var, d_bn_scale_var, d_bn_bias_var});
return bn_grad;
@@ -1404,8 +1450,12 @@ PDNode *patterns::BatchNormAddAct::operator()(
pattern->NewNode(act_out_repr())->assert_is_ops_output(act_types, "Out");
bn->LinksFrom({bn_x_var, bn_scale_var, bn_bias_var})
.LinksTo({bn_mean_out_var, bn_variance_out_var, bn_saved_variance_var,
bn_saved_mean_var, bn_reserve_space, bn_out_var});
.LinksTo({bn_mean_out_var,
bn_variance_out_var,
bn_saved_variance_var,
bn_saved_mean_var,
bn_reserve_space,
bn_out_var});
elewise_add->LinksFrom({elewise_add_in_var, bn_out_var})
.LinksTo({elewise_add_out_var});
act->LinksFrom({elewise_add_out_var}).LinksTo({act_out_var});
@@ -1484,8 +1534,13 @@ PDNode *patterns::BatchNormAddActGrad::operator()(
.LinksTo({d_elewise_add_in_var, d_bn_out_var});
bn_grad
->LinksFrom({bn_x_var, d_bn_out_var, bn_scale_var, bn_bias_var,
bn_saved_mean_var, bn_saved_variance_var, bn_reserve_space})
->LinksFrom({bn_x_var,
d_bn_out_var,
bn_scale_var,
bn_bias_var,
bn_saved_mean_var,
bn_saved_variance_var,
bn_reserve_space})
.LinksTo({d_bn_x_var, d_bn_scale_var, d_bn_bias_var});
return bn_grad;
@@ -1558,7 +1613,8 @@ PDNode *patterns::ElewiseAddAct::operator()(
PDNode *patterns::LinearAct::operator()(
paddle::framework::ir::PDNode *linear_x_var,
const std::unordered_set<std::string> &act_types, bool with_grad_link,
const std::unordered_set<std::string> &act_types,
bool with_grad_link,
bool is_act_grad_x_from_act) {
auto *matmul_w_var =
pattern->NewNode(matmul_w_repr())->assert_is_op_input("matmul_v2", "Y");
@@ -1621,7 +1677,8 @@ PDNode *patterns::LinearAct::operator()(
PDNode *patterns::ElewiseAddMatmulAct::operator()(
paddle::framework::ir::PDNode *dout_var,
const std::unordered_set<std::string> &act_grad_types,
bool without_x_gradient, bool is_act_grad_x_from_act) {
bool without_x_gradient,
bool is_act_grad_x_from_act) {
auto *ele_grad_bias_var =
pattern->NewNode(ele_grad_bias_repr())
->assert_is_op_input("elementwise_add_grad", "Y");
@@ -2052,7 +2109,8 @@ PDNode *patterns::Pool::operator()() {
return output_var;
}
PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var,
PDNode *patterns::Elementwise::operator()(PDNode *x_var,
PDNode *y_var,
const std::string elementwise_type) {
auto elementwise_op =
pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type);
@@ -2084,7 +2142,9 @@ PDNode *patterns::ElementwiseOp::operator()(
}
PDNode *patterns::ResidualElementwise::operator()(
PDNode *op_var, PDNode *residual_var, const std::string elementwise_type,
PDNode *op_var,
PDNode *residual_var,
const std::string elementwise_type,
bool as_x) {
auto elementwise_op =
pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type);
@@ -3065,7 +3125,8 @@ void patterns::DeleteQuantDequantLinearOpPattern::operator()() {
}
PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
const std::string &op_name, bool with_reshape_xshape,
const std::string &op_name,
bool with_reshape_xshape,
bool with_transpose_xshape) {
auto reshape_op =
pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2");
@@ -3098,8 +3159,7 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
transpose_out->assert_is_only_output_of_op("transpose2");
auto transpose_xshape =
with_transpose_xshape
? pattern->NewNode(transpose_xshape_repr())
with_transpose_xshape ? pattern->NewNode(transpose_xshape_repr())
->AsIntermediate()
->assert_is_op_output("transpose2", "XShape")
: nullptr;
@@ -122,10 +122,12 @@ struct PDNode {
PDNode* assert_is_op_input(const std::string& op_type,
const std::string& argument);
PDNode* assert_is_op_nth_input(const std::string& op_type,
const std::string& argument, int nth);
const std::string& argument,
int nth);
PDNode* assert_is_not_op_input(const std::string& argument);
PDNode* assert_is_op_nth_output(const std::string& op_type,
const std::string& argument, int nth);
const std::string& argument,
int nth);
PDNode* assert_is_only_input_of_op(const std::string& op_type);
PDNode* assert_is_only_output_of_op(const std::string& op_type);
PDNode* assert_op_has_n_inputs(const std::string& op_type, size_t n);
@@ -138,13 +140,15 @@ struct PDNode {
const std::string& argument);
PDNode* assert_is_ops_nth_input(
const std::unordered_set<std::string>& op_types,
const std::string& argument, int nth);
const std::string& argument,
int nth);
PDNode* assert_is_ops_input(const std::unordered_set<std::string>& op_types);
PDNode* assert_is_ops_input(const std::unordered_set<std::string>& op_types,
const std::string& argument);
PDNode* assert_is_ops_nth_output(
const std::unordered_set<std::string>& op_types,
const std::string& argument, int nth);
const std::string& argument,
int nth);
PDNode* assert_is_only_input_of_ops(
const std::unordered_set<std::string>& op_types);
@@ -164,10 +168,13 @@ struct PDNode {
}
private:
PDNode(PDPattern* pattern, const std::string& name = "",
PDNode(PDPattern* pattern,
const std::string& name = "",
Type type = Type::kVar)
: pattern_(pattern), name_(name), type_(type) {}
PDNode(teller_t&& teller, PDPattern* pattern, const std::string& name = "",
PDNode(teller_t&& teller,
PDPattern* pattern,
const std::string& name = "",
Type type = Type::kVar)
: teller_(std::move(teller)),
pattern_(pattern),
@@ -398,16 +405,25 @@ struct KeyCounter {
return x;
}
#ifdef PADDLE_WITH_TENSORRT
static int IncCounter(const std::string& key) { return dic_[key]++; }
static void CleanCounter() { dic_.clear(); }
private:
static thread_local std::unordered_map<std::string, size_t> dic_;
#else
int IncCounter(const std::string& key) { return dic_[key]++; }
private:
std::unordered_map<std::string, size_t> dic_;
#endif
};
// Generate a unique PDNode's name with name_scope and id.
// The format is {name_scope}/{repr}/{id}/{name}
static std::string PDNodeName(const std::string& name_scope,
const std::string& repr, size_t id,
const std::string& repr,
size_t id,
const std::string& name) {
return string::Sprintf("%s/%s/%d/%s", name_scope, repr, id, name);
}
@@ -415,15 +431,15 @@ static std::string PDNodeName(const std::string& name_scope,
// The format is {name_scope}/{repr}/{id}
static std::string PDNodeName(const std::string& name_scope,
const std::string& repr) {
return string::Sprintf("%s/%s/%d", name_scope, repr,
KeyCounter::Instance().IncCounter(repr));
return string::Sprintf(
"%s/%s/%d", name_scope, repr, KeyCounter::Instance().IncCounter(repr));
}
// Generate a unique key. It can be used for a universally unique temporary
// name.
// The format is {repr}/{id}
static std::string UniqueKey(const std::string& repr) {
return string::Sprintf("%s/%d", repr,
KeyCounter::Instance().IncCounter(repr));
return string::Sprintf(
"%s/%d", repr, KeyCounter::Instance().IncCounter(repr));
}
// Declare a PDNode in a pattern, will create two methods:
@@ -440,17 +456,19 @@ static std::string UniqueKey(const std::string& repr) {
// arg: the argument declared by PATTERN_DECL_NODE in a pattern definition.
// pat: the pattern object.
#define GET_IR_NODE_FROM_SUBGRAPH(var, arg, pat) \
PADDLE_ENFORCE_NE(subgraph.count(pat.arg##_n()), 0UL, \
PADDLE_ENFORCE_NE(subgraph.count(pat.arg##_n()), \
0UL, \
platform::errors::NotFound("Node not found for PDNode %s", \
pat.arg##_repr())); \
Node* var = subgraph.at(pat.arg##_n()); \
PADDLE_ENFORCE_NOT_NULL( \
var, platform::errors::NotFound("node %s not exists in the sub-graph", \
#arg));
PADDLE_ENFORCE_NOT_NULL(var, \
platform::errors::NotFound( \
"node %s not exists in the sub-graph", #arg));
// The base class of all the patterns.
struct PatternBase {
PatternBase(PDPattern* pattern, const std::string& name_scope,
PatternBase(PDPattern* pattern,
const std::string& name_scope,
const std::string& repr)
: pattern(pattern),
name_scope_(name_scope),
@@ -476,7 +494,8 @@ struct ConvBN : public PatternBase {
ConvBN(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_bn") {}
PDNode* operator()(PDNode* conv_input, const std::string& conv_type,
PDNode* operator()(PDNode* conv_input,
const std::string& conv_type,
bool with_eltwise_add);
// declare operator node's name
@@ -514,7 +533,8 @@ struct ConvActivation : public PatternBase {
ConvActivation(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_activation") {}
PDNode* operator()(PDNode* conv_input, std::string conv_type = "conv2d",
PDNode* operator()(PDNode* conv_input,
std::string conv_type = "conv2d",
std::string activation_type = "relu");
// declare operator node's name
@@ -536,7 +556,8 @@ struct ElementwiseActivation : public PatternBase {
ElementwiseActivation(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elementwise_add_activation") {}
PDNode* operator()(PDNode* elementwise_a, const std::string& elementwise_type,
PDNode* operator()(PDNode* elementwise_a,
const std::string& elementwise_type,
const std::string& activation_type);
// declare operator node's name
@@ -936,7 +957,8 @@ struct LinearAct : public PatternBase {
PDNode* operator()(PDNode* x,
const std::unordered_set<std::string>& act_types,
bool with_grad_link, bool is_act_grad_x_from_act);
bool with_grad_link,
bool is_act_grad_x_from_act);
// declare operator node's name
PATTERN_DECL_NODE(matmul);
@@ -965,7 +987,8 @@ struct ElewiseAddMatmulAct : public PatternBase {
PDNode* operator()(PDNode* x,
const std::unordered_set<std::string>& act_grad_types,
bool without_x_gradient, bool is_act_grad_x_from_act);
bool without_x_gradient,
bool is_act_grad_x_from_act);
// declare operator node's name
PATTERN_DECL_NODE(ele_add_grad);
@@ -1062,7 +1085,8 @@ struct Elementwise : public PatternBase {
Elementwise(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elementwise") {}
PDNode* operator()(PDNode* x_var, PDNode* y_var,
PDNode* operator()(PDNode* x_var,
PDNode* y_var,
const std::string elementwise_type);
PATTERN_DECL_NODE(elementwise_op);
@@ -1088,11 +1112,14 @@ struct ElementwiseOp : public PatternBase {
// This pattern allows operator output to be X or Y
// and residual data Y or X, based on as_x flag
struct ResidualElementwise : public PatternBase {
ResidualElementwise(PDPattern* pattern, const std::string& name_scope,
ResidualElementwise(PDPattern* pattern,
const std::string& name_scope,
bool as_x)
: PatternBase(pattern, name_scope, "residual_elementwise") {}
PDNode* operator()(PDNode* op_var, PDNode* residual_var,
const std::string elementwise_type, bool as_x);
PDNode* operator()(PDNode* op_var,
PDNode* residual_var,
const std::string elementwise_type,
bool as_x);
PATTERN_DECL_NODE(operator_output);
PATTERN_DECL_NODE(residual_data);
@@ -1467,8 +1494,8 @@ struct ConvElementwiseaddAct : public PatternBase {
// Conv + ElementwiseAdd + ElementwiseAdd + Activation
struct ConvElementwiseadd2Act : public PatternBase {
ConvElementwiseadd2Act(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope,
"conv_elementwiseadd2_elementwiseadd_act") {}
: PatternBase(
pattern, name_scope, "conv_elementwiseadd2_elementwiseadd_act") {}
PDNode* operator()(PDNode* conv_in);
@@ -1702,7 +1729,8 @@ struct DequantOpFuse : public PatternBase {
DequantOpFuse(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "dequant_fuse") {}
void operator()(PDNode* quant_op_input, const std::string& quantized_op_type,
void operator()(PDNode* quant_op_input,
const std::string& quantized_op_type,
const std::string& dequant_type,
const std::string& weight_name);
@@ -1758,8 +1786,8 @@ struct DeleteQuantDequantOpPattern : public PatternBase {
struct DeleteQuantDequantFilterOpPattern : public PatternBase {
DeleteQuantDequantFilterOpPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope,
"delete_quantdequant_filter_op_pattern") {}
: PatternBase(
pattern, name_scope, "delete_quantdequant_filter_op_pattern") {}
void operator()();
@@ -1773,7 +1801,8 @@ struct DeleteQuantDequantFilterOpPattern : public PatternBase {
struct DeleteWeightQuantDequantLinearOpPattern : public PatternBase {
DeleteWeightQuantDequantLinearOpPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope,
: PatternBase(pattern,
name_scope,
"delete_weight_quant_dequant_linear_op_pattern") {}
void operator()();
@@ -1788,8 +1817,8 @@ struct DeleteWeightQuantDequantLinearOpPattern : public PatternBase {
struct DeleteQuantDequantLinearOpPattern : public PatternBase {
DeleteQuantDequantLinearOpPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope,
"delete_quant_dequant_linear_op_pattern") {}
: PatternBase(
pattern, name_scope, "delete_quant_dequant_linear_op_pattern") {}
void operator()();
@@ -1814,7 +1843,8 @@ struct ReshapeTransposeMatmulPattern : public PatternBase {
const std::string& name_scope)
: PatternBase(pattern, name_scope, "reshape_transpose_matmul") {}
PDNode* operator()(const std::string& op_name, bool with_reshape_xshape,
PDNode* operator()(const std::string& op_name,
bool with_reshape_xshape,
bool with_transpose_xshape);
PATTERN_DECL_NODE(reshape_in);
@@ -82,9 +82,9 @@ namespace paddle {
using inference::Singleton;
#if PADDLE_WITH_TENSORRT
using inference::tensorrt::TRTInt8Calibrator;
using inference::tensorrt::TRTCalibratorEngine;
using inference::tensorrt::TRTCalibratorEngineManager;
using inference::tensorrt::TRTInt8Calibrator;
#endif
int AnalysisPredictor::clone_num_ = 1;
@@ -101,7 +101,8 @@ bool IsPersistable(const framework::VarDesc *var) {
}
} // namespace
bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
bool PaddleTensorToLoDTensor(const PaddleTensor &pt,
framework::LoDTensor *t,
const platform::Place &place) {
framework::DDim ddim = phi::make_ddim(pt.shape);
void *input_ptr;
@@ -129,18 +130,19 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
if (platform::is_cpu_place(place)) {
// TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy.
std::memcpy(static_cast<void *>(input_ptr), pt.data.data(),
pt.data.length());
std::memcpy(
static_cast<void *>(input_ptr), pt.data.data(), pt.data.length());
} else if (platform::is_ipu_place(place)) {
#ifdef PADDLE_WITH_IPU
std::memcpy(static_cast<void *>(input_ptr), pt.data.data(),
pt.data.length());
std::memcpy(
static_cast<void *>(input_ptr), pt.data.data(), pt.data.length());
#else
PADDLE_THROW(paddle::platform::errors::Fatal(
"Not compile with WITH_IPU, should not reach here."));
#endif
} else if (platform::is_gpu_place(place)) {
PADDLE_ENFORCE_EQ(platform::is_xpu_place(place), false,
PADDLE_ENFORCE_EQ(platform::is_xpu_place(place),
false,
platform::errors::InvalidArgument(
"Only one choice can be made between CPU and XPU."));
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -148,8 +150,11 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
auto *dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Get(place));
auto dst_gpu_place = place;
memory::Copy(dst_gpu_place, static_cast<void *>(input_ptr),
platform::CPUPlace(), pt.data.data(), pt.data.length(),
memory::Copy(dst_gpu_place,
static_cast<void *>(input_ptr),
platform::CPUPlace(),
pt.data.data(),
pt.data.length(),
dev_ctx->stream());
#else
PADDLE_THROW(paddle::platform::errors::Fatal(
@@ -158,8 +163,11 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
} else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
auto dst_xpu_place = place;
memory::Copy(dst_xpu_place, static_cast<void *>(input_ptr),
platform::CPUPlace(), pt.data.data(), pt.data.length());
memory::Copy(dst_xpu_place,
static_cast<void *>(input_ptr),
platform::CPUPlace(),
pt.data.data(),
pt.data.length());
#else
PADDLE_THROW(paddle::platform::errors::Fatal(
"Not compile with XPU, should not reach here."));
@@ -263,7 +271,8 @@ bool AnalysisPredictor::PrepareProgram(
}
bool AnalysisPredictor::CreateExecutor() {
if (config_.use_gpu()) {
PADDLE_ENFORCE_EQ(config_.use_xpu(), false,
PADDLE_ENFORCE_EQ(config_.use_xpu(),
false,
platform::errors::InvalidArgument(
"Only one choice can be made between CPU and XPU."));
place_ = paddle::platform::CUDAPlace(config_.gpu_device_id());
@@ -357,7 +366,8 @@ static bool IsPrepareDataOptTargetOp(framework::OpDesc *op) {
}
static void DisablePrepareDataOpt(
std::shared_ptr<framework::ProgramDesc> inference_program, int block,
std::shared_ptr<framework::ProgramDesc> inference_program,
int block,
bool pre_disable_opt) {
bool disable_opt = false;
auto &infer_block = inference_program->Block(block);
@@ -367,8 +377,8 @@ static void DisablePrepareDataOpt(
}
if (op->HasAttr("sub_block")) {
int blockID = op->GetBlockAttrId("sub_block");
DisablePrepareDataOpt(inference_program, blockID,
disable_opt || pre_disable_opt);
DisablePrepareDataOpt(
inference_program, blockID, disable_opt || pre_disable_opt);
}
// disable prepare data if unfriendly op is found
if (!disable_opt) {
@@ -386,8 +396,8 @@ bool AnalysisPredictor::PrepareExecutor() {
#endif
DisablePrepareDataOpt(inference_program_, 0, false);
executor_->Prepare(sub_scope_, *inference_program_, 0,
config_.use_feed_fetch_ops_);
executor_->Prepare(
sub_scope_, *inference_program_, 0, config_.use_feed_fetch_ops_);
PADDLE_ENFORCE_NOT_NULL(sub_scope_,
platform::errors::PreconditionNotMet(
@@ -433,8 +443,13 @@ bool AnalysisPredictor::PrepareFleetExecutor() {
feed_fetch_vars.emplace_back(pair.second);
}
fleet_exe_->Init(config_.dist_config().carrier_id(),
*(inference_program_.get()), scope_.get(), place_, 1,
{task_node_.get()}, id_to_rank, feed_fetch_vars);
*(inference_program_.get()),
scope_.get(),
place_,
1,
{task_node_.get()},
id_to_rank,
feed_fetch_vars);
return true;
}
@@ -471,8 +486,12 @@ bool AnalysisPredictor::CommInit() {
peer_endpoints.emplace_back(
config_.dist_config().trainer_endpoints()[rank]);
}
InsertCommOp(var_name_base + std::to_string(order), ranks_in_group,
rank_in_group, peer_endpoints, comm_init_block, ring_id);
InsertCommOp(var_name_base + std::to_string(order),
ranks_in_group,
rank_in_group,
peer_endpoints,
comm_init_block,
ring_id);
order += 1;
}
framework::NaiveExecutor e(place_);
@@ -484,8 +503,11 @@ bool AnalysisPredictor::CommInit() {
}
void AnalysisPredictor::InsertCommOp(
std::string tmp_var_name, int nranks, int rank,
const std::vector<std::string> &peer_endpoints, framework::BlockDesc *block,
std::string tmp_var_name,
int nranks,
int rank,
const std::vector<std::string> &peer_endpoints,
framework::BlockDesc *block,
int ring_id) {
/*
* tmp_var_name: the var name for var comm_id
@@ -542,7 +564,8 @@ bool AnalysisPredictor::LoadConverterConfig(
<< config_.dist_config().comm_init_config() << "\n";
std::ifstream fin(config_.dist_config().comm_init_config(), std::ios::in);
PADDLE_ENFORCE_EQ(
static_cast<bool>(fin.is_open()), true,
static_cast<bool>(fin.is_open()),
true,
platform::errors::NotFound(
"Cannot open file %s, please confirm whether the file is normal.",
config_.dist_config().comm_init_config()));
@@ -686,8 +709,9 @@ bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs,
timer.tic();
// set feed variable
framework::Scope *scope = sub_scope_ ? sub_scope_ : scope_.get();
PADDLE_ENFORCE_NOT_NULL(scope, platform::errors::PreconditionNotMet(
"The scope should not be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::PreconditionNotMet("The scope should not be nullptr."));
if (!SetFeed(inputs, scope)) {
LOG(ERROR) << "fail to set feed";
return false;
@@ -790,9 +814,11 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
for (size_t i = 0; i < fetches_.size(); ++i) {
int idx = BOOST_GET_CONST(int, fetches_[i]->GetAttr("col"));
PADDLE_ENFORCE_EQ(
static_cast<size_t>(idx), i,
static_cast<size_t>(idx),
i,
platform::errors::InvalidArgument(
"Fetch op's col attr(%d) should be equal to the index(%d)", idx,
"Fetch op's col attr(%d) should be equal to the index(%d)",
idx,
i));
framework::FetchType &fetch_var =
framework::GetFetchVariable(*scope, "fetch", idx);
@@ -833,7 +859,8 @@ void AnalysisPredictor::PrepareArgument() {
if (!config_.model_dir().empty()) {
argument_.SetModelDir(config_.model_dir());
} else {
PADDLE_ENFORCE_EQ(config_.prog_file().empty(), false,
PADDLE_ENFORCE_EQ(config_.prog_file().empty(),
false,
platform::errors::PreconditionNotMet(
"Either model_dir or prog_file should be set."));
std::string dir = inference::analysis::GetDirRoot(config_.prog_file());
@@ -969,7 +996,8 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
Analyzer().Run(&argument_);
PADDLE_ENFORCE_EQ(
argument_.scope_valid(), true,
argument_.scope_valid(),
true,
platform::errors::InvalidArgument("The argument scope should be valid."));
VLOG(5) << "to prepare executor";
ARGUMENT_CHECK_FIELD((&argument_), ir_analyzed_program);
@@ -1008,8 +1036,9 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
}
template <>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
AnalysisConfig, PaddleEngineKind::kAnalysis>(const AnalysisConfig &config) {
std::unique_ptr<PaddlePredictor>
CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
const AnalysisConfig &config) {
// TODO(NHZlX): Should add the link to the doc of
// paddle_infer::CreatePredictor<paddle_infer::Config>
if (config.glog_info_disabled()) {
@@ -1018,7 +1047,8 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
}
VLOG(3) << "create AnalysisConfig";
PADDLE_ENFORCE_EQ(
config.is_valid(), true,
config.is_valid(),
true,
platform::errors::InvalidArgument(
"Note: Each config can only be used for one predictor."));
@@ -1035,11 +1065,13 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
std::call_once(gflags_initialized, [&]() {
std::vector<std::string> gflags;
PADDLE_ENFORCE_GE(
config.memory_pool_init_size_mb(), 0.f,
config.memory_pool_init_size_mb(),
0.f,
platform::errors::InvalidArgument(
"The size of memory pool should be greater than 0."));
PADDLE_ENFORCE_GE(
config.gpu_device_id(), 0,
config.gpu_device_id(),
0,
platform::errors::InvalidArgument(
"Invalid device id (%d). The device id should be greater than 0.",
config.gpu_device_id()));
@@ -1105,6 +1137,10 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
config.SetInValid();
auto predictor_p = dynamic_cast<AnalysisPredictor *>(predictor.get());
#ifdef PADDLE_WITH_TENSORRT
paddle::framework::ir::patterns::KeyCounter::Instance().CleanCounter();
#endif
if (!predictor_p->Init(nullptr)) {
return nullptr;
}
@@ -1154,8 +1190,9 @@ void AnalysisPredictor::PrepareFeedFetch() {
}
void AnalysisPredictor::CreateFeedFetchVar(framework::Scope *scope) {
PADDLE_ENFORCE_NOT_NULL(scope, platform::errors::InvalidArgument(
"The scope should not be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::InvalidArgument("The scope should not be nullptr."));
auto *var = scope->Var("feed");
var->GetMutable<framework::FeedList>();
var = scope->Var("fetch");
@@ -1176,8 +1213,9 @@ AnalysisPredictor::GetInputTensorShape() {
std::vector<std::string> names = GetInputNames();
for (std::string name : names) {
auto *var = inference_program_->Block(0).FindVar(name);
PADDLE_ENFORCE_NOT_NULL(var, platform::errors::PreconditionNotMet(
"Input %s does not exist.", name));
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::PreconditionNotMet("Input %s does not exist.", name));
input_shapes[name] = var->GetShape();
}
return input_shapes;
@@ -1398,7 +1436,8 @@ void AnalysisPredictor::StatisticShapeRangeInfo() {
std::vector<std::pair<int32_t, int32_t>> counter;
for (auto &it : m) counter.push_back(it);
std::sort(
counter.begin(), counter.end(),
counter.begin(),
counter.end(),
[](std::pair<int32_t, int32_t> &a, std::pair<int32_t, int32_t> &b) {
return a.second > b.second;
});
@@ -1420,8 +1459,8 @@ void AnalysisPredictor::StatisticShapeRangeInfo() {
opt_shapes[name] = opt_shape;
}
inference::SerializeShapeRangeInfo(config_.shape_range_info_path(),
min_shapes, max_shapes, opt_shapes);
inference::SerializeShapeRangeInfo(
config_.shape_range_info_path(), min_shapes, max_shapes, opt_shapes);
}
bool AnalysisPredictor::LoadProgramDesc() {
@@ -1441,7 +1480,8 @@ bool AnalysisPredictor::LoadProgramDesc() {
return false;
}
LOG(ERROR) << string::Sprintf(
"not valid model path '%s' or program path '%s'.", config_.model_dir(),
"not valid model path '%s' or program path '%s'.",
config_.model_dir(),
config_.params_file());
return false;
}
@@ -1453,7 +1493,8 @@ bool AnalysisPredictor::LoadProgramDesc() {
// Read binary
std::ifstream fin(filename, std::ios::in | std::ios::binary);
PADDLE_ENFORCE_EQ(
static_cast<bool>(fin.is_open()), true,
static_cast<bool>(fin.is_open()),
true,
platform::errors::NotFound(
"Cannot open file %s, please confirm whether the file is normal.",
filename));
@@ -1555,7 +1596,8 @@ void AnalysisPredictor::ClearIntermediateTensor() {
#if PADDLE_WITH_TENSORRT
bool AnalysisPredictor::SaveTrtCalibToDisk() {
PADDLE_ENFORCE_EQ(config_.tensorrt_engine_enabled(), true,
PADDLE_ENFORCE_EQ(config_.tensorrt_engine_enabled(),
true,
platform::errors::PreconditionNotMet(
"This func can be invoked only in trt mode"));
auto &block = inference_program_->Block(0);
@@ -1782,8 +1824,10 @@ Predictor::Predictor(const Config &config) {
<< "Paddle2ONNX do't support convert the Model, fall back to using "
"Paddle Inference.";
} else {
predictor_ = paddle::CreatePaddlePredictor<
Config, paddle::PaddleEngineKind::kONNXRuntime>(config);
predictor_ =
paddle::CreatePaddlePredictor<Config,
paddle::PaddleEngineKind::kONNXRuntime>(
config);
return;
}
#else
@@ -1793,8 +1837,10 @@ Predictor::Predictor(const Config &config) {
"fall back to using Paddle Inference.";
#endif
}
predictor_ = paddle::CreatePaddlePredictor<
Config, paddle::PaddleEngineKind::kAnalysis>(config);
predictor_ =
paddle::CreatePaddlePredictor<Config,
paddle::PaddleEngineKind::kAnalysis>(
config);
}
std::vector<std::string> Predictor::GetInputNames() {
......@@ -1876,7 +1922,8 @@ std::shared_ptr<Predictor> CreatePredictor(const Config &config) { // NOLINT
namespace services {
PredictorPool::PredictorPool(const Config &config, size_t size) {
PADDLE_ENFORCE_GE(
size, 1UL,
size,
1UL,
paddle::platform::errors::InvalidArgument(
"The predictor pool size should be greater than 1, but it's (%d)",
size));
@@ -1895,9 +1942,11 @@ PredictorPool::PredictorPool(const Config &config, size_t size) {
Predictor *PredictorPool::Retrive(size_t idx) {
PADDLE_ENFORCE_LT(
idx, preds_.size() + 1,
idx,
preds_.size() + 1,
paddle::platform::errors::InvalidArgument(
"There are (%d) predictors in the pool, but the idx is (%d)", idx,
"There are (%d) predictors in the pool, but the idx is (%d)",
idx,
preds_.size() + 1));
if (idx == 0) {
return main_pred_.get();