提交 fea1bba2 作者: Megvii Engine Team 提交者: huangxinda

fix(mge/tools): fix module status error

GitOrigin-RevId: 38e3eeb34bad42a7b28340c544d19b7d3f079b8a
上级 6cd01d5a
...@@ -14,6 +14,7 @@ from collections import namedtuple ...@@ -14,6 +14,7 @@ from collections import namedtuple
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
import megengine as mge
from megengine.core.tensor.dtype import is_quantize from megengine.core.tensor.dtype import is_quantize
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
from megengine.utils.module_stats import ( from megengine.utils.module_stats import (
...@@ -119,7 +120,9 @@ def visualize( ...@@ -119,7 +120,9 @@ def visualize(
flops_list = [] flops_list = []
params_list = [] params_list = []
activations_list = [] activations_list = []
total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"]) total_stats = namedtuple(
"total_stats", ["param_size", "param_dims", "flops", "act_size", "act_dims"]
)
stats_details = namedtuple("module_stats", ["params", "flops", "activations"]) stats_details = namedtuple("module_stats", ["params", "flops", "activations"])
for node in tqdm(graph.all_oprs): for node in tqdm(graph.all_oprs):
...@@ -166,14 +169,14 @@ def visualize( ...@@ -166,14 +169,14 @@ def visualize(
flops_list.append(flops_stats) flops_list.append(flops_stats)
if cal_activations: if cal_activations:
acts = get_activation_stats(node_oup.numpy(), has_input=has_input) acts = get_activation_stats(node_oup, has_input=has_input)
acts["name"] = node.name acts["name"] = node.name
acts["class_name"] = node.type acts["class_name"] = node.type
activations_list.append(acts) activations_list.append(acts)
if cal_params: if cal_params:
if node.type == "ImmutableTensor": if node.type == "ImmutableTensor":
param_stats = get_param_stats(node.numpy()) param_stats = get_param_stats(node_oup)
# add tensor size attr # add tensor size attr
if log_path: if log_path:
attr["size"] = AttrValue( attr["size"] = AttrValue(
...@@ -248,7 +251,11 @@ def visualize( ...@@ -248,7 +251,11 @@ def visualize(
return ( return (
total_stats( total_stats(
param_size=total_param_size, flops=total_flops, act_size=total_act_size, param_size=total_param_size,
param_dims=total_param_dims,
flops=total_flops,
act_size=total_act_size,
act_dims=total_act_dims,
), ),
stats_details( stats_details(
params=params_list, flops=flops_list, activations=activations_list params=params_list, flops=flops_list, activations=activations_list
...@@ -263,6 +270,10 @@ def main(): ...@@ -263,6 +270,10 @@ def main():
) )
parser.add_argument("model_path", help="dumped model path.") parser.add_argument("model_path", help="dumped model path.")
parser.add_argument("--log_path", help="tensorboard log path.") parser.add_argument("--log_path", help="tensorboard log path.")
parser.add_argument(
"--load_input_data",
help="load input data from pickle file; it should be a numpy array or a dict of numpy array",
)
parser.add_argument( parser.add_argument(
"--bar_length_max", "--bar_length_max",
type=int, type=int,
...@@ -295,6 +306,19 @@ def main(): ...@@ -295,6 +306,19 @@ def main():
help="whether print all stats. Tensorboard logs will be placed in './log' if not specified.", help="whether print all stats. Tensorboard logs will be placed in './log' if not specified.",
) )
args = parser.parse_args() args = parser.parse_args()
if args.load_input_data:
logger.info("load data from {}".format(args.load_input_data))
data = mge.load(args.load_input_data)
if isinstance(data, dict):
for v in data.values():
assert isinstance(
v, np.ndarray
), "data should provide ndarray; got {} instead".format(v)
args.inp_dict = data
elif isinstance(data, np.ndarray):
args.input = data
else:
logger.error("input data should be a numpy array or a dict of numpy array")
if args.all: if args.all:
args.cal_params = True args.cal_params = True
args.cal_flops = True args.cal_flops = True
...@@ -304,6 +328,7 @@ def main(): ...@@ -304,6 +328,7 @@ def main():
args.log_path = "./log" args.log_path = "./log"
kwargs = vars(args) kwargs = vars(args)
kwargs.pop("all") kwargs.pop("all")
kwargs.pop("load_input_data")
visualize(**kwargs) visualize(**kwargs)
......
...@@ -113,7 +113,12 @@ def flops_norm(module: m.Linear, inputs, outputs): ...@@ -113,7 +113,12 @@ def flops_norm(module: m.Linear, inputs, outputs):
@register_flops(m.AvgPool2d, m.MaxPool2d) @register_flops(m.AvgPool2d, m.MaxPool2d)
def flops_pool(module: m.AvgPool2d, inputs, outputs): def flops_pool(module: m.AvgPool2d, inputs, outputs):
return np.prod(outputs[0].shape) * (module.kernel_size ** 2) kernel_sum = 0
if isinstance(module.kernel_size, tuple) and len(module.kernel_size) == 2:
kernel_sum = np.prod(module.kernel_size)
else:
kernel_sum = module.kernel_size ** 2
return np.prod(outputs[0].shape) * kernel_sum
@register_flops(m.AdaptiveAvgPool2d, m.AdaptiveMaxPool2d) @register_flops(m.AdaptiveAvgPool2d, m.AdaptiveMaxPool2d)
...@@ -157,12 +162,12 @@ hook_modules = ( ...@@ -157,12 +162,12 @@ hook_modules = (
def _mean(inp): def _mean(inp):
inp = mge.tensor(inp) inp = mge.tensor(inp).astype(np.float32)
return F.mean(inp).numpy() return F.mean(inp).numpy()
def _std(inp): def _std(inp):
inp = mge.tensor(inp) inp = mge.tensor(inp).astype(np.float32)
return F.std(inp).numpy() return F.std(inp).numpy()
...@@ -337,7 +342,7 @@ def print_param_stats(params): ...@@ -337,7 +342,7 @@ def print_param_stats(params):
) )
def get_activation_stats(output: np.ndarray, has_input=False): def get_activation_stats(output: Tensor, has_input=False):
out_shape = output.shape out_shape = output.shape
activations_dtype = np.dtype(output.dtype) activations_dtype = np.dtype(output.dtype)
nbits = get_dtype_bit(activations_dtype.name) nbits = get_dtype_bit(activations_dtype.name)
...@@ -351,8 +356,8 @@ def get_activation_stats(output: np.ndarray, has_input=False): ...@@ -351,8 +356,8 @@ def get_activation_stats(output: np.ndarray, has_input=False):
"size": act_size, "size": act_size,
} }
if has_input: if has_input:
activation_stats["mean"] = "{:.3g}".format(output.mean()) activation_stats["mean"] = "{:.3g}".format(_mean(output))
activation_stats["std"] = "{:.3g}".format(output.std()) activation_stats["std"] = "{:.3g}".format(_std(output))
return activation_stats return activation_stats
...@@ -462,21 +467,21 @@ def module_stats( ...@@ -462,21 +467,21 @@ def module_stats(
if cal_params: if cal_params:
if hasattr(module, "weight") and module.weight is not None: if hasattr(module, "weight") and module.weight is not None:
w = module.weight w = module.weight
param_stats = get_param_stats(w.numpy()) param_stats = get_param_stats(w)
param_stats["name"] = name + "-w" param_stats["name"] = name + "-w"
params.append(param_stats) params.append(param_stats)
if hasattr(module, "bias") and module.bias is not None: if hasattr(module, "bias") and module.bias is not None:
b = module.bias b = module.bias
param_stats = get_param_stats(b.numpy()) param_stats = get_param_stats(b)
param_stats["name"] = name + "-b" param_stats["name"] = name + "-b"
params.append(param_stats) params.append(param_stats)
if cal_activations: if cal_activations:
if not isinstance(outputs, (tuple, list)): if not isinstance(outputs, (tuple, list)):
output = outputs.numpy() output = outputs
else: else:
output = outputs[0].numpy() output = outputs[0]
activation_stats = get_activation_stats(output, has_inputs) activation_stats = get_activation_stats(output, has_inputs)
activation_stats["name"] = name activation_stats["name"] = name
activation_stats["class_name"] = class_name activation_stats["class_name"] = class_name
...@@ -486,7 +491,9 @@ def module_stats( ...@@ -486,7 +491,9 @@ def module_stats(
flops = [] flops = []
hooks = [] hooks = []
activations = [] activations = []
total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"]) total_stats = namedtuple(
"total_stats", ["param_size", "param_dims", "flops", "act_size", "act_dims"]
)
stats_details = namedtuple("module_stats", ["params", "flops", "activations"]) stats_details = namedtuple("module_stats", ["params", "flops", "activations"])
for (name, module) in model.named_modules(): for (name, module) in model.named_modules():
...@@ -536,7 +543,7 @@ def module_stats( ...@@ -536,7 +543,7 @@ def module_stats(
if logging_to_stdout: if logging_to_stdout:
print_activations_stats(activations, has_inputs) print_activations_stats(activations, has_inputs)
if cal_flops and cal_params: if cal_flops and cal_params and total_param_size != 0:
extra_info["flops/param_size"] = "{:3.3f}".format( extra_info["flops/param_size"] = "{:3.3f}".format(
total_flops / total_param_size total_flops / total_param_size
) )
...@@ -545,7 +552,11 @@ def module_stats( ...@@ -545,7 +552,11 @@ def module_stats(
return ( return (
total_stats( total_stats(
param_size=total_param_size, flops=total_flops, act_size=total_act_size, param_size=total_param_size,
param_dims=total_param_dims,
flops=total_flops,
act_size=total_act_size,
act_dims=total_act_dims,
), ),
stats_details(params=params, flops=flops, activations=activations), stats_details(params=params, flops=flops, activations=activations),
) )
...@@ -21,16 +21,10 @@ def test_module_stats(): ...@@ -21,16 +21,10 @@ def test_module_stats():
total_stats, stats_details = module_stats(net, input_shapes=input_shape) total_stats, stats_details = module_stats(net, input_shapes=input_shape)
x1 = np.random.random((1, 3, 224, 224)).astype("float32") x1 = np.random.random((1, 3, 224, 224)).astype("float32")
gt_flops, gt_acts = net.get_stats(mge.tensor(x1)) gt_flops, gt_acts = net.get_stats(mge.tensor(x1))
assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == ( assert (total_stats.flops, total_stats.act_dims) == (gt_flops, gt_acts,)
gt_flops,
gt_acts,
)
total_stats, stats_details = module_stats(net, inputs=x1) total_stats, stats_details = module_stats(net, inputs=x1)
assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == ( assert (total_stats.flops, total_stats.act_dims) == (gt_flops, gt_acts,)
gt_flops,
gt_acts,
)
class BasicBlock(M.Module): class BasicBlock(M.Module):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册