未验证 提交 75433126 编写于 作者: L LielinJiang 提交者: GitHub

Fix summary bug when calaculating output shape (#31549)

* fix summary bug
上级 c3634c6b
...@@ -341,10 +341,12 @@ def summary_string(model, input_size, dtypes=None): ...@@ -341,10 +341,12 @@ def summary_string(model, input_size, dtypes=None):
total_params += summary[layer]["nb_params"] total_params += summary[layer]["nb_params"]
try: try:
total_output += np.prod(summary[layer]["output_shape"]) total_output += np.sum(
np.prod(
summary[layer]["output_shape"], axis=-1))
except: except:
for output_shape in summary[layer]["output_shape"]: for output_shape in summary[layer]["output_shape"]:
total_output += np.prod(output_shape) total_output += np.sum(np.prod(output_shape, axis=-1))
if "trainable" in summary[layer]: if "trainable" in summary[layer]:
if summary[layer]["trainable"] == True: if summary[layer]["trainable"] == True:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册