未验证 提交 1de3cdd0 编写于 作者: L LielinJiang 提交者: GitHub

Fix summary api for rnn gru lstm (#28566)

* fix summary for rnn gru lstm
上级 a24d1868
......@@ -244,6 +244,9 @@ def summary_string(model, input_size, dtypes=None):
(not (layer == model) or depth < 1)):
hooks.append(layer.register_forward_post_hook(hook))
# For rnn, gru and lstm layer
elif hasattr(layer, 'could_use_cudnn') and layer.could_use_cudnn:
hooks.append(layer.register_forward_post_hook(hook))
if isinstance(input_size, tuple):
input_size = [input_size]
......
......@@ -295,6 +295,12 @@ class TestModel(unittest.TestCase):
np.testing.assert_equal(output[0].shape[0], len(self.test_dataset))
fluid.disable_dygraph()
def test_summary_gpu(self):
paddle.disable_static(self.device)
rnn = paddle.nn.LSTM(16, 32, 2)
params_info = paddle.summary(
rnn, [(-1, 23, 16), ((2, None, 32), (2, -1, 32))])
class MyModel(paddle.nn.Layer):
def __init__(self):
......@@ -512,14 +518,33 @@ class TestModelFunction(unittest.TestCase):
model.summary(input_size=(20), dtype='float32')
def test_summary_nlp(self):
paddle.enable_static()
def _get_param_from_state_dict(state_dict):
params = 0
for k, v in state_dict.items():
params += np.prod(v.numpy().shape)
return params
nlp_net = paddle.nn.GRU(input_size=2,
hidden_size=3,
num_layers=3,
direction="bidirectional")
paddle.summary(nlp_net, (1, 1, 2))
rnn = paddle.nn.LSTM(16, 32, 2)
paddle.summary(rnn, [(-1, 23, 16), ((2, None, 32), (2, -1, 32))])
params_info = paddle.summary(
rnn, [(-1, 23, 16), ((2, None, 32), (2, -1, 32))])
gt_params = _get_param_from_state_dict(rnn.state_dict())
np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0)
rnn = paddle.nn.GRU(16, 32, 2, direction='bidirectional')
params_info = paddle.summary(rnn, (4, 23, 16))
gt_params = _get_param_from_state_dict(rnn.state_dict())
np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0)
rnn = paddle.nn.SimpleRNN(16, 32, 2, direction='bidirectional')
params_info = paddle.summary(rnn, (4, 23, 16))
gt_params = _get_param_from_state_dict(rnn.state_dict())
np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0)
def test_summary_dtype(self):
input_shape = (3, 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册