未验证 提交 53d3f5eb 编写于 作者: L LielinJiang 提交者: GitHub

add sample code for summary (#33337)

上级 7528b1e8
......@@ -80,6 +80,23 @@ def summary(net, input_size, dtypes=None):
params_info = paddle.summary(lenet, (1, 1, 28, 28))
print(params_info)
# multi input demo
class LeNetMultiInput(LeNet):
def forward(self, inputs, y):
x = self.features(inputs)
if self.num_classes > 0:
x = paddle.flatten(x, 1)
x = self.fc(x + y)
return x
lenet_multi_input = LeNetMultiInput()
params_info = paddle.summary(lenet_multi_input, [(1, 1, 28, 28), (1, 400)],
['float32', 'float32'])
print(params_info)
"""
if isinstance(input_size, InputSpec):
_input_size = tuple(input_size.shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册