未验证 提交 75433126 编写于 作者: L LielinJiang 提交者: GitHub

Fix summary bug when calaculating output shape (#31549)

* fix summary bug
上级 c3634c6b
......@@ -341,10 +341,12 @@ def summary_string(model, input_size, dtypes=None):
total_params += summary[layer]["nb_params"]
try:
total_output += np.prod(summary[layer]["output_shape"])
total_output += np.sum(
np.prod(
summary[layer]["output_shape"], axis=-1))
except:
for output_shape in summary[layer]["output_shape"]:
total_output += np.prod(output_shape)
total_output += np.sum(np.prod(output_shape, axis=-1))
if "trainable" in summary[layer]:
if summary[layer]["trainable"] == True:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册