机器未来 / Paddle
Forked from PaddlePaddle / Paddle

Commit 80edd7ef
Authored on Aug 31, 2018 by tensor-tang

enable run with fuse pass

Parent: a79a77ee

Showing 2 changed files with 28 additions and 7 deletions (+28 -7)
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc  (+27 -6)
paddle/fluid/operators/fusion_lstm_op.cc  (+1 -1)

paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
@@ -13,6 +13,8 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/lod_tensor.h"
 
 namespace paddle {
 namespace framework {
@@ -35,7 +37,6 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
     auto* id = subgraph.at(gpd.pattern().RetrieveNode("any_node"));
     marked_nodes.insert(id);
   };
@@ -73,12 +74,31 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
     op_desc.SetOutput("Hidden", {hidden_n->Name()});
     op_desc.SetOutput("Cell", {cell_n->Name()});
     op_desc.SetOutput("XX", {xx_n->Name()});
-    op_desc.SetOutput("BatchedGate", {"blstm_0.tmp_2"});
-    op_desc.SetOutput("BatchCellPreAct", {"blstm_1.tmp_2"});
+    op_desc.SetOutput("BatchedInput", {"blstm_0.tmp_2"});
     op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
     op_desc.SetAttr("use_peepholes", false);
+
+#define TMP_NAME(x) "at.new.tmp." #x
+#define OP_SET_OUT(x) op_desc.SetOutput(#x, {TMP_NAME(x)})
+    OP_SET_OUT(BatchedCell);
+    OP_SET_OUT(BatchedHidden);
+    OP_SET_OUT(ReorderedH0);
+    OP_SET_OUT(ReorderedC0);
+#undef OP_SET_OUT
+
     auto* op = graph->CreateOpNode(&op_desc);
+    PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
+    auto* scope = graph->Get<Scope*>(kParamScopeAttr);
+
+#define TMP_NEW(x) scope->Var(TMP_NAME(x))->GetMutable<LoDTensor>()
+    TMP_NEW(BatchedCell);
+    TMP_NEW(BatchedHidden);
+    TMP_NEW(ReorderedH0);
+    TMP_NEW(ReorderedC0);
+#undef TMP_NEW
+#undef TMP_NAME
 
 #define LINK_TO(a, b)      \
   a->outputs.push_back(b); \
   b->inputs.push_back(a);
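The OP_SET_OUT / TMP_NEW helpers added above simply stamp out one SetOutput call and one scope variable per intermediate tensor of the fused op. As a purely illustrative expansion (not part of the commit), with TMP_NAME(x) defined as "at.new.tmp." #x, the BatchedCell pair preprocesses to:

// OP_SET_OUT(BatchedCell): adjacent string literals concatenate into one name.
op_desc.SetOutput("BatchedCell", {"at.new.tmp.BatchedCell"});
// TMP_NEW(BatchedCell): creates the backing variable in the parameter scope.
scope->Var("at.new.tmp.BatchedCell")->GetMutable<LoDTensor>();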
@@ -89,7 +109,6 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
     LINK_TO(op, hidden_n);
 #undef LINK_TO
     return op;
   };
   lstm_creator(16, 12, 14, 18, 17, 22, 21, 19);
@@ -105,16 +124,18 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
     for (auto it = node->inputs.begin(); it != node->inputs.end();) {
       if (marked_nodes.count(*it)) {
         it = const_cast<Node*>(node)->inputs.erase(it);
       } else {
         it++;
       }
     }
     for (auto it = node->outputs.begin(); it != node->outputs.end();) {
       if (marked_nodes.count(*it)) {
         it = const_cast<Node*>(node)->outputs.erase(it);
       } else {
         it++;
       }
     }
   }
 
   return graph;
 }
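The cleanup loops above use the standard erase-while-iterating idiom: std::vector::erase returns the iterator to the element following the one removed, so the iterator is only advanced manually when nothing was erased. A minimal, self-contained illustration of the same pattern (not code from this commit):

#include <set>
#include <vector>

// Remove every element of *v that appears in 'marked', mirroring how the pass
// strips marked nodes from each node's inputs/outputs lists.
void RemoveMarked(std::vector<int>* v, const std::set<int>& marked) {
  for (auto it = v->begin(); it != v->end();) {
    if (marked.count(*it)) {
      it = v->erase(it);  // erase returns the next valid iterator
    } else {
      ++it;
    }
  }
}
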
paddle/fluid/operators/fusion_lstm_op.cc
@@ -22,7 +22,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/sequence2batch.h"
 #include "paddle/fluid/platform/cpu_info.h"
 
-DEFINE_bool(seq_mode, false, "Use sequence mode");
+DEFINE_bool(seq_mode, true, "Use sequence mode");
 
 namespace paddle {
 namespace operators {
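The one-line change in this file flips the default of the gflags seq_mode flag from false to true, so the fused LSTM now takes its sequence-mode path unless overridden at launch time (e.g. --seq_mode=false). A hypothetical sketch of how such a flag is typically consumed elsewhere; the FLAGS_seq_mode branch below is an assumption for illustration, not code from this diff:

#include <gflags/gflags.h>

DECLARE_bool(seq_mode);  // refers to the DEFINE_bool(seq_mode, true, ...) above

void RunFusionLstm(/* inputs elided */) {
  if (FLAGS_seq_mode) {
    // sequence-mode kernel (the new default)
  } else {
    // batched kernel
  }
}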