未验证 提交 4676f03c 编写于 作者: L LielinJiang 提交者: GitHub

fix summary (#27820)

上级 840d54de
......@@ -254,7 +254,7 @@ def summary_string(model, input_size, dtypes=None):
dtype = dtypes[0]
else:
dtype = dtypes
return paddle.rand(list(input_size), dtype)
return paddle.cast(paddle.rand(list(input_size)), dtype)
else:
return [
build_input(i, dtype) for i, dtype in zip(input_size, dtypes)
......
......@@ -501,6 +501,11 @@ class TestModelFunction(unittest.TestCase):
rnn = paddle.nn.LSTM(16, 32, 2)
paddle.summary(rnn, [(-1, 23, 16), ((2, None, 32), (2, -1, 32))])
def test_summary_dtype(self):
input_shape = (3, 1)
net = paddle.nn.Embedding(10, 3, sparse=True)
paddle.summary(net, input_shape, dtypes='int64')
def test_summary_error(self):
with self.assertRaises(TypeError):
nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册