From 1de3cdd0abd947f2830915e5f2d9bedcb7297c98 Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Mon, 16 Nov 2020 11:26:56 +0800 Subject: [PATCH] Fix summary api for rnn gru lstm (#28566) * fix summary for rnn gru lstm --- python/paddle/hapi/model_summary.py | 3 +++ python/paddle/tests/test_model.py | 29 +++++++++++++++++++++++++++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/python/paddle/hapi/model_summary.py b/python/paddle/hapi/model_summary.py index c6288ea40c..babbe962a9 100644 --- a/python/paddle/hapi/model_summary.py +++ b/python/paddle/hapi/model_summary.py @@ -244,6 +244,9 @@ def summary_string(model, input_size, dtypes=None): (not (layer == model) or depth < 1)): hooks.append(layer.register_forward_post_hook(hook)) + # For rnn, gru and lstm layer + elif hasattr(layer, 'could_use_cudnn') and layer.could_use_cudnn: + hooks.append(layer.register_forward_post_hook(hook)) if isinstance(input_size, tuple): input_size = [input_size] diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py index a3b33d6f25..ab7a3654e5 100644 --- a/python/paddle/tests/test_model.py +++ b/python/paddle/tests/test_model.py @@ -295,6 +295,12 @@ class TestModel(unittest.TestCase): np.testing.assert_equal(output[0].shape[0], len(self.test_dataset)) fluid.disable_dygraph() + def test_summary_gpu(self): + paddle.disable_static(self.device) + rnn = paddle.nn.LSTM(16, 32, 2) + params_info = paddle.summary( + rnn, [(-1, 23, 16), ((2, None, 32), (2, -1, 32))]) + class MyModel(paddle.nn.Layer): def __init__(self): @@ -512,14 +518,33 @@ class TestModelFunction(unittest.TestCase): model.summary(input_size=(20), dtype='float32') def test_summary_nlp(self): - paddle.enable_static() + def _get_param_from_state_dict(state_dict): + params = 0 + for k, v in state_dict.items(): + params += np.prod(v.numpy().shape) + return params + nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3, direction="bidirectional") paddle.summary(nlp_net, (1, 1, 2)) + rnn = paddle.nn.LSTM(16, 32, 2) - paddle.summary(rnn, [(-1, 23, 16), ((2, None, 32), (2, -1, 32))]) + params_info = paddle.summary( + rnn, [(-1, 23, 16), ((2, None, 32), (2, -1, 32))]) + gt_params = _get_param_from_state_dict(rnn.state_dict()) + np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0) + + rnn = paddle.nn.GRU(16, 32, 2, direction='bidirectional') + params_info = paddle.summary(rnn, (4, 23, 16)) + gt_params = _get_param_from_state_dict(rnn.state_dict()) + np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0) + + rnn = paddle.nn.SimpleRNN(16, 32, 2, direction='bidirectional') + params_info = paddle.summary(rnn, (4, 23, 16)) + gt_params = _get_param_from_state_dict(rnn.state_dict()) + np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0) def test_summary_dtype(self): input_shape = (3, 1) -- GitLab