提交 ec348147 编写于 作者: H Hui Zhang

fix layer tools

上级 b5339633
...@@ -56,8 +56,8 @@ def print_params(model, print_func=print): ...@@ -56,8 +56,8 @@ def print_params(model, print_func=print):
if print_func: if print_func:
print_func(msg) print_func(msg)
if print_func: if print_func:
total = total / 1024**3 total = total / 1024**2
print_func(f"Total parameters: {num_params}, {total}G elements.") print_func(f"Total parameters: {num_params}, {total:.4f}M elements.")
def gradient_norm(layer: nn.Layer): def gradient_norm(layer: nn.Layer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册