Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
fd7fd0c4
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fd7fd0c4
编写于
9月 27, 2020
作者:
Z
zhhsplendid
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Dy2stat] Fix lstm bug, test=develop
上级
9b7ebf10
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
71 addition
and
2 deletion
+71
-2
python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py
...ddle/fluid/dygraph/dygraph_to_static/convert_call_func.py
+14
-1
python/paddle/fluid/layers/rnn.py
python/paddle/fluid/layers/rnn.py
+1
-1
python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py
...ddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py
+56
-0
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py
浏览文件 @
fd7fd0c4
...
@@ -65,7 +65,20 @@ def is_unsupported(func):
...
@@ -65,7 +65,20 @@ def is_unsupported(func):
Checks whether the func is supported by dygraph to static graph.
Checks whether the func is supported by dygraph to static graph.
"""
"""
if
any
(
func
in
m
.
__dict__
.
values
()
for
m
in
BUILTIN_LIKELY_MODULES
):
func_in_builtin_modules
=
False
for
m
in
BUILTIN_LIKELY_MODULES
:
for
v
in
m
.
__dict__
.
values
():
func_in_dict
=
func
==
v
if
isinstance
(
func_in_dict
,
list
)
or
isinstance
(
func_in_dict
,
numpy
.
ndarray
):
func_in_dict
=
any
(
func_in_dict
)
if
func_in_dict
:
func_in_builtin_modules
=
True
break
if
func_in_builtin_modules
:
break
if
func_in_builtin_modules
:
translator_logger
.
log
(
translator_logger
.
log
(
2
,
2
,
"Whitelist: {} is part of built-in module and does not have to be transformed."
.
"Whitelist: {} is part of built-in module and does not have to be transformed."
.
...
...
python/paddle/fluid/layers/rnn.py
浏览文件 @
fd7fd0c4
...
@@ -619,7 +619,7 @@ def _rnn_static_graph(cell,
...
@@ -619,7 +619,7 @@ def _rnn_static_graph(cell,
inputs
=
map_structure
(
rnn
.
step_input
,
inputs
)
inputs
=
map_structure
(
rnn
.
step_input
,
inputs
)
states
=
map_structure
(
rnn
.
memory
,
initial_states
)
states
=
map_structure
(
rnn
.
memory
,
initial_states
)
copy_states
=
map_structure
(
lambda
x
:
x
,
states
)
copy_states
=
map_structure
(
lambda
x
:
x
,
states
)
outputs
,
new_states
=
cell
.
call
(
inputs
,
copy_states
,
**
kwargs
)
outputs
,
new_states
=
cell
(
inputs
,
copy_states
,
**
kwargs
)
assert_same_structure
(
states
,
new_states
)
assert_same_structure
(
states
,
new_states
)
if
sequence_length
:
if
sequence_length
:
step_mask
=
rnn
.
step_input
(
mask
)
step_mask
=
rnn
.
step_input
(
mask
)
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py
0 → 100644
浏览文件 @
fd7fd0c4
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
paddle
import
unittest
from
paddle
import
nn
class
Net
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
hidden_size
):
super
(
Net
,
self
).
__init__
()
self
.
lstm
=
nn
.
LSTM
(
in_channels
,
hidden_size
,
direction
=
'bidirectional'
,
num_layers
=
2
)
@
paddle
.
jit
.
to_static
def
forward
(
self
,
x
):
x
,
_
=
self
.
lstm
(
x
)
return
x
class
TestLstm
(
unittest
.
TestCase
):
def
run_lstm
(
self
,
to_static
):
paddle
.
jit
.
ProgramTranslator
().
enable
(
to_static
)
paddle
.
disable_static
()
paddle
.
static
.
default_main_program
().
random_seed
=
1001
paddle
.
static
.
default_startup_program
().
random_seed
=
1001
net
=
Net
(
12
,
2
)
x
=
paddle
.
zeros
((
2
,
10
,
12
))
y
=
net
(
paddle
.
to_tensor
(
x
))
return
y
.
numpy
()
def
test_lstm_to_static
(
self
):
dygraph_out
=
self
.
run_lstm
(
to_static
=
False
)
static_out
=
self
.
run_lstm
(
to_static
=
True
)
self
.
assertTrue
(
np
.
allclose
(
dygraph_out
,
static_out
),
msg
=
'dygraph_out is {}
\n
static_out is
\n
{}'
.
format
(
dygraph_out
,
static_out
))
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录