未验证 提交 1cc35f36 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #13118 from tensor-tang/optimize/op/fusion_lstm

Optimize fusion lstm batch mode
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h" #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
namespace paddle { namespace paddle {
...@@ -94,11 +95,31 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -94,11 +95,31 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
op_desc.SetOutput("Hidden", {hidden_n->Name()}); op_desc.SetOutput("Hidden", {hidden_n->Name()});
op_desc.SetOutput("Cell", {cell_n->Name()}); op_desc.SetOutput("Cell", {cell_n->Name()});
op_desc.SetOutput("XX", {xx_n->Name()}); op_desc.SetOutput("XX", {xx_n->Name()});
op_desc.SetOutput("BatchedGate", {"blstm_0.tmp_2"}); op_desc.SetOutput("BatchedInput", {"blstm_0.tmp_2"});
op_desc.SetOutput("BatchCellPreAct", {"blstm_1.tmp_2"});
op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse")); op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("use_peepholes", lstm_n->Op()->GetAttr("use_peepholes")); op_desc.SetAttr("use_peepholes", lstm_n->Op()->GetAttr("use_peepholes"));
// TODO(TJ): get from attr
op_desc.SetAttr("use_seq", true);
#define TMP_NAME(x) "at.new.tmp." #x
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {TMP_NAME(x)})
OP_SET_OUT(BatchedCell);
OP_SET_OUT(BatchedHidden);
OP_SET_OUT(ReorderedH0);
OP_SET_OUT(ReorderedC0);
#undef OP_SET_OUT
auto* op = graph->CreateOpNode(&op_desc); auto* op = graph->CreateOpNode(&op_desc);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr);
#define TMP_NEW(x) scope->Var(TMP_NAME(x))->GetMutable<LoDTensor>()
TMP_NEW(BatchedCell);
TMP_NEW(BatchedHidden);
TMP_NEW(ReorderedH0);
TMP_NEW(ReorderedC0);
#undef TMP_NEW
#undef TMP_NAME
#define LINK_TO(a, b) \ #define LINK_TO(a, b) \
a->outputs.push_back(b); \ a->outputs.push_back(b); \
...@@ -116,7 +137,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -116,7 +137,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
auto fc_no_bias_handler = [&]( auto fc_no_bias_handler = [&](
const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
#define GET_NODE(name__) \ #define GET_NODE(name__) \
std::string name__##key = name_scope + "/" + #name__; \ std::string name__##key = name_scope + "/" + #name__; \
auto* name__##n = pattern->RetrieveNode(name__##key); \ auto* name__##n = pattern->RetrieveNode(name__##key); \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册