未验证 提交 abb32bbc 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Dy2stat] Fix lstm bug (#27631)

This PR fixed two bugs when converting LSTM in dy2stat:

is_unsupported has a condition can trigger Python syntax error
LSTM API's implementation in _rnn_static_graph doesn't include parameter initialization, which can cause dy2stat error.
上级 b7b1ae86
...@@ -65,12 +65,17 @@ def is_unsupported(func): ...@@ -65,12 +65,17 @@ def is_unsupported(func):
Checks whether the func is supported by dygraph to static graph. Checks whether the func is supported by dygraph to static graph.
""" """
if any(func in m.__dict__.values() for m in BUILTIN_LIKELY_MODULES): for m in BUILTIN_LIKELY_MODULES:
translator_logger.log( for v in m.__dict__.values():
2, func_in_dict = func == v
"Whitelist: {} is part of built-in module and does not have to be transformed.". if isinstance(func_in_dict, (list, numpy.ndarray)):
format(func)) func_in_dict = any(func_in_dict)
return True if func_in_dict:
translator_logger.log(
2,
"Whitelist: {} is part of built-in module and does not have to be transformed.".
format(func))
return True
if is_paddle_func(func): if is_paddle_func(func):
translator_logger.log( translator_logger.log(
......
...@@ -623,7 +623,7 @@ def _rnn_static_graph(cell, ...@@ -623,7 +623,7 @@ def _rnn_static_graph(cell,
inputs = map_structure(rnn.step_input, inputs) inputs = map_structure(rnn.step_input, inputs)
states = map_structure(rnn.memory, initial_states) states = map_structure(rnn.memory, initial_states)
copy_states = map_structure(lambda x: x, states) copy_states = map_structure(lambda x: x, states)
outputs, new_states = cell.call(inputs, copy_states, **kwargs) outputs, new_states = cell(inputs, copy_states, **kwargs)
assert_same_structure(states, new_states) assert_same_structure(states, new_states)
if sequence_length: if sequence_length:
step_mask = rnn.step_input(mask) step_mask = rnn.step_input(mask)
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import unittest
from paddle import nn
class Net(nn.Layer):
def __init__(self, in_channels, hidden_size):
super(Net, self).__init__()
self.lstm = nn.LSTM(
in_channels, hidden_size, direction='bidirectional', num_layers=2)
@paddle.jit.to_static
def forward(self, x):
x, _ = self.lstm(x)
return x
class TestLstm(unittest.TestCase):
def run_lstm(self, to_static):
paddle.jit.ProgramTranslator().enable(to_static)
paddle.disable_static()
paddle.static.default_main_program().random_seed = 1001
paddle.static.default_startup_program().random_seed = 1001
net = Net(12, 2)
x = paddle.zeros((2, 10, 12))
y = net(paddle.to_tensor(x))
return y.numpy()
def test_lstm_to_static(self):
dygraph_out = self.run_lstm(to_static=False)
static_out = self.run_lstm(to_static=True)
self.assertTrue(
np.allclose(dygraph_out, static_out),
msg='dygraph_out is {}\n static_out is \n{}'.format(dygraph_out,
static_out))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册