提交 80edd7ef 编写于 作者: T tensor-tang

enable run with fuse pass

上级 a79a77ee
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h" #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -35,7 +37,6 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl( ...@@ -35,7 +37,6 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
auto* id = subgraph.at(gpd.pattern().RetrieveNode("any_node")); auto* id = subgraph.at(gpd.pattern().RetrieveNode("any_node"));
marked_nodes.insert(id); marked_nodes.insert(id);
}; };
...@@ -73,12 +74,31 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl( ...@@ -73,12 +74,31 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
op_desc.SetOutput("Hidden", {hidden_n->Name()}); op_desc.SetOutput("Hidden", {hidden_n->Name()});
op_desc.SetOutput("Cell", {cell_n->Name()}); op_desc.SetOutput("Cell", {cell_n->Name()});
op_desc.SetOutput("XX", {xx_n->Name()}); op_desc.SetOutput("XX", {xx_n->Name()});
op_desc.SetOutput("BatchedGate", {"blstm_0.tmp_2"}); op_desc.SetOutput("BatchedInput", {"blstm_0.tmp_2"});
op_desc.SetOutput("BatchCellPreAct", {"blstm_1.tmp_2"});
op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse")); op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("use_peepholes", false); op_desc.SetAttr("use_peepholes", false);
#define TMP_NAME(x) "at.new.tmp." #x
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {TMP_NAME(x)})
OP_SET_OUT(BatchedCell);
OP_SET_OUT(BatchedHidden);
OP_SET_OUT(ReorderedH0);
OP_SET_OUT(ReorderedC0);
#undef OP_SET_OUT
auto* op = graph->CreateOpNode(&op_desc); auto* op = graph->CreateOpNode(&op_desc);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr);
#define TMP_NEW(x) scope->Var(TMP_NAME(x))->GetMutable<LoDTensor>()
TMP_NEW(BatchedCell);
TMP_NEW(BatchedHidden);
TMP_NEW(ReorderedH0);
TMP_NEW(ReorderedC0);
#undef TMP_NEW
#undef TMP_NAME
#define LINK_TO(a, b) \ #define LINK_TO(a, b) \
a->outputs.push_back(b); \ a->outputs.push_back(b); \
b->inputs.push_back(a); b->inputs.push_back(a);
...@@ -89,7 +109,6 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl( ...@@ -89,7 +109,6 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
LINK_TO(op, hidden_n); LINK_TO(op, hidden_n);
#undef LINK_TO #undef LINK_TO
return op; return op;
}; };
lstm_creator(16, 12, 14, 18, 17, 22, 21, 19); lstm_creator(16, 12, 14, 18, 17, 22, 21, 19);
...@@ -105,14 +124,16 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl( ...@@ -105,14 +124,16 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
for (auto it = node->inputs.begin(); it != node->inputs.end();) { for (auto it = node->inputs.begin(); it != node->inputs.end();) {
if (marked_nodes.count(*it)) { if (marked_nodes.count(*it)) {
it = const_cast<Node*>(node)->inputs.erase(it); it = const_cast<Node*>(node)->inputs.erase(it);
} else } else {
it++; it++;
}
} }
for (auto it = node->outputs.begin(); it != node->outputs.end();) { for (auto it = node->outputs.begin(); it != node->outputs.end();) {
if (marked_nodes.count(*it)) { if (marked_nodes.count(*it)) {
it = const_cast<Node*>(node)->outputs.erase(it); it = const_cast<Node*>(node)->outputs.erase(it);
} else } else {
it++; it++;
}
} }
} }
......
...@@ -22,7 +22,7 @@ limitations under the License. */ ...@@ -22,7 +22,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
DEFINE_bool(seq_mode, false, "Use sequence mode"); DEFINE_bool(seq_mode, true, "Use sequence mode");
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册