未验证 提交 835201bf 编写于 作者: L liu zhengxi 提交者: GitHub

Fix multi-threads memory out of bounds error for passes (#21920) (#22132)

* fix seqconv_eltadd_relu pass during multi-threads predictor, test=develop

* fix attention_lstm_fuse_pass during multi-threads inference, test=develop

* fix embedding_fc_lstm_fuse_pass during multi-threads inference, test=develop

* fix fc_lstm_fuse_pass during multi-threads inference, test=develop

* fix seq_concat_fc_fuse_pass during multi-threads inference, test=develop
上级 5a611afd
......@@ -41,7 +41,7 @@ struct Param {
std::string LSTMOUT = "at.lstmout.new";
};
void PrepareParameters(Graph* graph, const Param& param);
void PrepareParameters(Graph* graph, const Param& param, ir::Node* lstm_op);
void FindWhileOp(Graph* graph) {
GraphPatternDetector gpd;
......@@ -98,7 +98,7 @@ void FindWhileOp(Graph* graph) {
auto* hidden_init = graph->RetrieveNode(8);
auto* lstm_op = graph->CreateOpNode(&op_desc);
PrepareParameters(graph, param);
PrepareParameters(graph, param, lstm_op);
IR_NODE_LINK_TO(X, lstm_op);
IR_NODE_LINK_TO(cell_init, lstm_op);
......@@ -133,20 +133,29 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
const LoDTensor& B_output, const LoDTensor& B_cell,
LoDTensor* out);
void PrepareParameters(Graph* graph, const Param& param) {
void PrepareParameters(Graph* graph, const Param& param, ir::Node* lstm_op) {
// Check parameters
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto& scope = graph->Get<Scope>(kParamScopeAttr);
// Create new parameters.
// AddInput
scope.Var(param.LSTMWeight)->GetMutable<LoDTensor>();
scope.Var(param.LSTMBias)->GetMutable<LoDTensor>();
scope.Var(param.Hidden)->GetMutable<LoDTensor>();
scope.Var(param.Cell)->GetMutable<LoDTensor>();
scope.Var(param.AttentionedX)->GetMutable<LoDTensor>();
scope.Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
scope.Var(param.LSTMX)->GetMutable<LoDTensor>();
scope.Var(param.LSTMOUT)->GetMutable<LoDTensor>();
// AddOutput
#define IR_NODE(x) \
VarDesc key_##x(param.x); \
key_##x.SetPersistable(false); \
auto* node_##x = graph->CreateVarNode(&key_##x); \
IR_NODE_LINK_TO(lstm_op, node_##x);
IR_NODE(Hidden);
IR_NODE(Cell);
IR_NODE(AttentionedX);
IR_NODE(AttentionFCOut);
IR_NODE(LSTMX);
IR_NODE(LSTMOUT);
#undef IR_NODE
#define GATE_W(name__) \
auto* W_##name__##_w0 = scope.FindVar(#name__ ".w_0"); \
......
......@@ -127,35 +127,24 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
embedding_data, k, weightx_data, n, beta, embeddings_data, n);
op_desc.SetInput("Embeddings", {embeddings});
// Create temp variables.
const std::string BatchedInput = patterns::UniqueKey("BatchedInput");
const std::string BatchedCellPreAct =
patterns::UniqueKey("BatchedCellPreAct");
const std::string BatchedGate = patterns::UniqueKey("BatchedGate");
scope->Var(BatchedInput)->GetMutable<framework::LoDTensor>();
scope->Var(BatchedCellPreAct)->GetMutable<framework::LoDTensor>();
scope->Var(BatchedGate)->GetMutable<framework::LoDTensor>();
op_desc.SetInput("H0", {});
op_desc.SetInput("C0", {});
op_desc.SetOutput("Hidden", {hidden->Name()});
op_desc.SetOutput("Cell", {cell->Name()});
op_desc.SetOutput("XX", {xx->Name()});
op_desc.SetOutput("BatchedGate", {BatchedGate});
op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct});
op_desc.SetOutput("BatchedInput", {BatchedInput});
op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
// TODO(TJ): get from attr
op_desc.SetAttr("use_seq", true);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto& scope = graph->Get<Scope>(kParamScopeAttr);
// Create temp variables.
#define OP_SET_OUT(x) \
const std::string x = patterns::UniqueKey(#x); \
op_desc.SetOutput(#x, {x}); \
scope.Var(x)->GetMutable<LoDTensor>()
op_desc.SetOutput(#x, {x});
OP_SET_OUT(BatchedGate);
OP_SET_OUT(BatchCellPreAct);
OP_SET_OUT(BatchedInput);
OP_SET_OUT(BatchedCell);
OP_SET_OUT(BatchedHidden);
OP_SET_OUT(ReorderedH0);
......@@ -163,11 +152,28 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
#undef OP_SET_OUT
auto* op = graph->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(input, op);
IR_NODE_LINK_TO(weight_x, op);
IR_NODE_LINK_TO(weight_h, op);
IR_NODE_LINK_TO(bias, op);
IR_NODE_LINK_TO(op, hidden);
#define IR_NODE(x) \
VarDesc key_##x(x); \
key_##x.SetPersistable(false); \
auto* node_##x = graph->CreateVarNode(&key_##x); \
IR_NODE_LINK_TO(op, node_##x);
IR_NODE(BatchedGate);
IR_NODE(BatchCellPreAct);
IR_NODE(BatchedInput);
IR_NODE(BatchedCell);
IR_NODE(BatchedHidden);
IR_NODE(ReorderedH0);
IR_NODE(ReorderedC0);
#undef IR_NODE
return op;
};
......
......@@ -74,38 +74,25 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
op_desc.SetInput("Bias", {new_bias_var});
}
// Create temp variables.
const std::string BatchedInput = patterns::UniqueKey("BatchedInput");
const std::string BatchedCellPreAct =
patterns::UniqueKey("BatchedCellPreAct");
const std::string BatchedGate = patterns::UniqueKey("BatchedGate");
const std::string CheckedCell = patterns::UniqueKey("CheckedCell");
scope->Var(BatchedInput)->GetMutable<framework::LoDTensor>();
scope->Var(BatchedCellPreAct)->GetMutable<framework::LoDTensor>();
scope->Var(BatchedGate)->GetMutable<framework::LoDTensor>();
scope->Var(CheckedCell)->GetMutable<framework::LoDTensor>();
op_desc.SetInput("H0", {});
op_desc.SetInput("C0", {});
op_desc.SetOutput("Hidden", {hidden->Name()});
op_desc.SetOutput("Cell", {cell->Name()});
op_desc.SetOutput("XX", {xx->Name()});
op_desc.SetOutput("BatchedGate", {BatchedGate});
op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct});
op_desc.SetOutput("BatchedInput", {BatchedInput});
op_desc.SetOutput("CheckedCell", {CheckedCell});
op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
// TODO(TJ): get from attr
op_desc.SetAttr("use_seq", true);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto& scope = graph->Get<Scope>(kParamScopeAttr);
// Create temp variables.
#define OP_SET_OUT(x) \
const std::string x = patterns::UniqueKey(#x); \
op_desc.SetOutput(#x, {x}); \
scope.Var(x)->GetMutable<LoDTensor>()
op_desc.SetOutput(#x, {x});
OP_SET_OUT(BatchedGate);
OP_SET_OUT(BatchedCellPreAct);
OP_SET_OUT(BatchedInput);
OP_SET_OUT(CheckedCell);
OP_SET_OUT(BatchedCell);
OP_SET_OUT(BatchedHidden);
OP_SET_OUT(ReorderedH0);
......@@ -113,11 +100,29 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
#undef OP_SET_OUT
auto* op = graph->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(input, op);
IR_NODE_LINK_TO(weight_x, op);
IR_NODE_LINK_TO(weight_h, op);
IR_NODE_LINK_TO(bias, op);
IR_NODE_LINK_TO(op, hidden);
#define IR_NODE(x) \
VarDesc key_##x(x); \
key_##x.SetPersistable(false); \
auto* node_##x = graph->CreateVarNode(&key_##x); \
IR_NODE_LINK_TO(op, node_##x);
IR_NODE(BatchedGate);
IR_NODE(BatchedCellPreAct);
IR_NODE(BatchedInput);
IR_NODE(CheckedCell);
IR_NODE(BatchedCell);
IR_NODE(BatchedHidden);
IR_NODE(ReorderedH0);
IR_NODE(ReorderedC0);
#undef IR_NODE
return op;
};
......
......@@ -214,7 +214,9 @@ void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const {
op_desc.SetInput("FCWeight", {fc_w->Name()});
op_desc.SetInput("FCBias", {fc_bias->Name()});
const std::string fc_out_tmp = fc_out->Name() + ".tmp";
param_scope()->Var(fc_out_tmp)->GetMutable<framework::LoDTensor>();
VarDesc fc_out_key(fc_out_tmp);
fc_out_key.SetPersistable(false);
auto* fc_out_node = graph->CreateVarNode(&fc_out_key);
op_desc.SetOutput("FCOut", {fc_out_tmp});
op_desc.SetOutput("Out", {fc_out->Name()});
op_desc.SetAttr("fc_activation", act->Op()->Type());
......@@ -227,6 +229,7 @@ void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const {
IR_NODE_LINK_TO(sequence_expand0_in, op_node);
IR_NODE_LINK_TO(sequence_expand1_in, op_node);
IR_NODE_LINK_TO(op_node, fc_out);
IR_NODE_LINK_TO(op_node, fc_out_node);
// Clean nodes.
std::unordered_set<const Node*> marked_nodes;
......
......@@ -42,18 +42,19 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
op_desc.SetAttr("contextLength", seqconv->Op()->GetAttr("contextLength"));
op_desc.SetAttr("contextStart", seqconv->Op()->GetAttr("contextStart"));
op_desc.SetAttr("contextStride", seqconv->Op()->GetAttr("contextStride"));
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto& scope = graph->Get<Scope>(kParamScopeAttr);
const std::string ColMat = patterns::UniqueKey("SeqConvColMat");
op_desc.SetOutput("ColMat", {ColMat});
op_desc.SetOutput("Out", {relu_out->Name()});
scope.Var(ColMat)->GetMutable<LoDTensor>();
VarDesc key(ColMat);
key.SetPersistable(false);
auto* key_col_mat = graph->CreateVarNode(&key);
auto* op = graph->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(input, op);
IR_NODE_LINK_TO(seqconv_weight, op);
IR_NODE_LINK_TO(eltadd_bias, op);
IR_NODE_LINK_TO(op, relu_out);
IR_NODE_LINK_TO(op, key_col_mat);
return op;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册