Commit 13481fd2 authored by Megvii Engine Team

feat(mge/tools): add support of receptive_field stats for NetworkNode

GitOrigin-RevId: 11ef3354689d343883348d4129bc89db784e3fe0
Parent 5a7c30e0
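The feature in a nutshell: receptive field and stride are propagated through the graph with the standard recurrence rf' = k * s_prev + rf_prev - s_prev and s' = s * s_prev, seeded with rf = (1, 1) and stride = (1, 1) at the inputs. A minimal standalone sketch of that recurrence (illustrative only, not part of the commit; the conv handler registered in network_node.py below uses the same formula):

def compose_rf(layers):
    # layers: (kernel, stride) pairs; returns the accumulated
    # (receptive_field, stride) after each layer
    rf, stride = 1, 1  # same seed as the fallback in module_stats below
    out = []
    for k, s in layers:
        rf = k * stride + rf - stride  # identical to the conv handler's formula
        stride = s * stride
        out.append((rf, stride))
    return out

# three 3x3 convs with strides 1, 2, 1: the receptive field grows 3 -> 5 -> 9
assert compose_rf([(3, 1), (3, 2), (3, 1)]) == [(3, 1), (5, 2), (9, 2)]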
@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import argparse
+import json
 import logging

 import numpy as np
@@ -14,6 +15,7 @@ import numpy as np
 from megengine.core.tensor.dtype import is_quantize
 from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
 from megengine.utils.module_stats import (
+    get_flops_stats,
     get_param_stats,
     print_flops_stats,
     print_params_stats,
@@ -89,6 +91,7 @@ def visualize(
         inp_list = [process_name(var.owner.name) for var in node.inputs]
         if log_path:
+            # detailed format: see tensorboard/compat/proto/attr_value.proto
             attr = {
                 "_output_shapes": AttrValue(
                     list=AttrValue.ListValue(
@@ -101,24 +104,20 @@ def visualize(
                         ]
                     )
                 ),
+                "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
+                "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
             }
-        if hasattr(node, "calc_flops"):
-            flops_num = node.calc_flops()
+        flops_stats = get_flops_stats(node, node.inputs, node.outputs)
+        if flops_stats is not None:
             # add op flops attr
-            if log_path:
+            if log_path and "flops_num" in flops_stats:
                 attr["flops"] = AttrValue(
-                    s=sizeof_fmt(flops_num).encode(encoding="utf-8")
+                    s=sizeof_fmt(flops_stats["flops_num"]).encode(encoding="utf-8")
                 )
-            flops_list.append(
-                dict(
-                    name=node.name,
-                    class_name=node.type,
-                    input_shapes=[i.shape for i in node.inputs],
-                    output_shapes=[o.shape for o in node.outputs],
-                    flops_num=flops_num,
-                    flops_cum=0,
-                )
-            )
+            flops_stats["name"] = node.name
+            flops_stats["class_name"] = node.type
+            flops_list.append(flops_stats)
         if node.type == "ImmutableTensor":
             param_stats = get_param_stats(node.numpy())
             # add tensor size attr
@@ -132,6 +131,7 @@ def visualize(
         # FIXME(MGE-2165): nodes outside network module may lead to unknown display bug
         if not len(node.name.split(".")) > 2 and not node in graph.input_vars:
             continue
+
         if log_path:
             node_list.append(
                 NodeDef(
@@ -141,14 +141,26 @@ def visualize(
                     attr=attr,
                 )
            )
+    # summary
+    extra_info = {
+        "#ops": len(graph.all_oprs),
+        "#params": len(params_list),
+    }
+
-    total_flops, total_params = None, None
+    total_flops, total_param_dims, total_param_size = 0, 0, 0
     if log_params:
         total_param_dims, total_param_size = print_params_stats(
             params_list, bar_length_max
         )
+        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
+        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
     if log_flops:
         total_flops = print_flops_stats(flops_list, bar_length_max)
+        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
+    if log_params and log_flops:
+        extra_info["flops/param_size"] = "{:3.3f}".format(
+            total_flops / total_param_size
+        )

     if log_path:
         graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
@@ -160,21 +172,12 @@ def visualize(
         writer = SummaryWriter(log_path)
         writer._get_file_writer().add_graph((graph_def, stepstats))
-    # summary
-    extra_info = {
-        "#ops": len(graph.all_oprs),
-        "#params": len(params_list),
-        "total_param_dims": sizeof_fmt(total_param_dims),
-        "total_param_size": sizeof_fmt(total_param_size),
-        "total_flops": sizeof_fmt(total_flops, suffix="OPs"),
-        "flops/param_size": "{:3.3f}".format(total_flops / total_param_size),
-    }
-
     print_summary(**extra_info)

     # FIXME: remove this after resolving "span dist too large" warning
     _imperative_rt_logger.set_log_level(old_level)
-    return total_params, total_flops
+    return total_param_size, total_flops


 def main():
...
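For reference, this is the shape of one `flops_list` entry after the visualizer above tags it; the keys come from `get_flops_stats` plus the `name`/`class_name` assignments, while the concrete values here are purely illustrative:

entry = {
    "input_shapes": [(1, 3, 224, 224), (64, 3, 7, 7)],
    "output_shapes": [(1, 64, 112, 112)],
    "flops_num": 118013952,       # 1*64*112*112 * (3*7*7), no bias term
    "receptive_field": (7, 7),    # present only for ops with a registered handler
    "stride": (2, 2),
    "name": "conv1",
    "class_name": "Convolution",
}
# print_flops_stats later adds "flops", "flops_cum", "ratio",
# "percentage" and "bar" before rendering the table.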
@@ -26,61 +26,95 @@ logger = mge.get_logger(__name__)
 logger.setLevel("INFO")

-CALC_FLOPS = {}
+_calc_flops_dict = {}
+_calc_receptive_field_dict = {}


-def _register_modules(*modules):
+def _receptive_field_fallback(module, inputs, outputs):
+    assert not hasattr(module, "_rf")
+    assert not hasattr(module, "_stride")
+    if len(inputs) == 0:
+        # TODO: support other dimension
+        module._rf = (1, 1)
+        module._stride = (1, 1)
+        return module._rf, module._stride
+    rf, stride = preprocess_receptive_field(module, inputs, outputs)
+    module._rf = rf
+    module._stride = stride
+    return rf, stride
+
+
+# each entry: result key (str or tuple of str), impl dict, fallback
+_iter_list = [
+    ("flops_num", _calc_flops_dict, None),
+    (
+        ("receptive_field", "stride"),
+        _calc_receptive_field_dict,
+        _receptive_field_fallback,
+    ),
+]
+
+
+def _register_dict(*modules, dict=None):
     def callback(impl):
         for module in modules:
-            CALC_FLOPS[module] = impl
+            dict[module] = impl
         return impl

     return callback
-@_register_modules(
-    m.Conv2d,
-    m.ConvTranspose2d,
-    m.LocalConv2d,
-    qm.Conv2d,
-    qm.ConvRelu2d,
-    qm.ConvBn2d,
-    qm.ConvBnRelu2d,
-    qatm.Conv2d,
-    qatm.ConvRelu2d,
-    qatm.ConvBn2d,
-    qatm.ConvBnRelu2d,
-)
-def count_convNd(module, input, output):
+def register_flops(*modules):
+    return _register_dict(*modules, dict=_calc_flops_dict)
+
+
+def register_receptive_field(*modules):
+    return _register_dict(*modules, dict=_calc_receptive_field_dict)
+
+
+@register_flops(
+    m.Conv1d, m.Conv2d, m.Conv3d,
+)
+def flops_convNd(module: m.Conv2d, inputs, outputs):
     bias = 1 if module.bias is not None else 0
     group = module.groups
-    ic = input[0].shape[1]
-    oc = output[0].shape[1]
+    ic = inputs[0].shape[1]
+    oc = outputs[0].shape[1]
     goc = oc // group
     gic = ic // group
-    N = output[0].shape[0]
-    HW = np.prod(output[0].shape[2:])
+    N = outputs[0].shape[0]
+    HW = np.prod(outputs[0].shape[2:])
     # N x Cout x H x W x (Cin x Kw x Kh + bias)
     return N * HW * goc * (gic * np.prod(module.kernel_size) + bias)


-@_register_modules(m.ConvTranspose2d)
-def count_deconvNd(module, input, output):
-    return np.prod(input[0].shape) * output[0].shape[1] * np.prod(module.kernel_size)
+@register_flops(m.ConvTranspose2d)
+def flops_deconvNd(module: m.ConvTranspose2d, inputs, outputs):
+    return np.prod(inputs[0].shape) * outputs[0].shape[1] * np.prod(module.kernel_size)


-@_register_modules(m.Linear, qatm.Linear, qm.Linear)
-def count_linear(module, input, output):
-    return np.prod(output[0].shape) * module.in_features
+@register_flops(m.Linear)
+def flops_linear(module: m.Linear, inputs, outputs):
+    bias = 1 if module.bias is not None else 0
+    return np.prod(outputs[0].shape) * (module.in_features + bias)
+
+
+@register_flops(m.BatchMatMulActivation)
+def flops_batchmatmul(module: m.BatchMatMulActivation, inputs, outputs):
+    bias = 1 if module.bias is not None else 0
+    x = inputs[0]
+    w = module.weight
+    batch_size = x.shape[0]
+    n, p = x.shape[1:]
+    _, m = w.shape[1:]
+    return n * (p + bias) * m * batch_size


-# does not need import qat and quantized module since they inherit from float module.
+# no need to import qat and quantized modules since they inherit from the float modules.
 hook_modules = (
-    m.Conv2d,
-    m.ConvTranspose2d,
-    m.LocalConv2d,
-    m.BatchNorm2d,
+    m.conv._ConvNd,
     m.Linear,
+    m.BatchMatMulActivation,
 )
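With the registry in place, adding flops support for a new module type is a one-decorator affair. A hedged sketch (MyUpsample2x is a hypothetical user-defined module class, and the one-operation-per-output-element cost is an assumption for illustration only):

@register_flops(MyUpsample2x)  # hypothetical module class
def flops_my_upsample(module, inputs, outputs):
    # assume one operation per output element
    return np.prod(outputs[0].shape)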
@@ -106,28 +140,71 @@ def sizeof_fmt(num, suffix="B"):
     return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix)


+def preprocess_receptive_field(module, inputs, outputs):
+    # TODO: support other dimensions
+    pre_rf = (
+        max(getattr(i.owner, "_rf", (1, 1))[0] for i in inputs),
+        max(getattr(i.owner, "_rf", (1, 1))[1] for i in inputs),
+    )
+    pre_stride = (
+        max(getattr(i.owner, "_stride", (1, 1))[0] for i in inputs),
+        max(getattr(i.owner, "_stride", (1, 1))[1] for i in inputs),
+    )
+    return pre_rf, pre_stride
+
+
+def get_flops_stats(module, inputs, outputs):
+    rst = {
+        "input_shapes": [i.shape for i in inputs],
+        "output_shapes": [o.shape for o in outputs],
+    }
+    valid_flag = False
+    for key, _dict, fallback in _iter_list:
+        for _type in _dict:
+            if isinstance(module, _type):
+                value = _dict[_type](module, inputs, outputs)
+                valid_flag = True
+                break
+        else:
+            # no registered impl matched: run the fallback for its side
+            # effects only and record no value for this key
+            if fallback is not None:
+                value = fallback(module, inputs, outputs)
+            continue
+
+        if isinstance(key, tuple):
+            assert isinstance(value, tuple)
+            for k, v in zip(key, value):
+                rst[k] = v
+        else:
+            rst[key] = value
+
+    if valid_flag:
+        return rst
+    else:
+        return None
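The for/else in get_flops_stats is easy to misread: the else clause runs only when the inner loop finishes without hitting break, i.e. when no registered type matched, and then the fallback (if any) fires purely for its side effects. A standalone sketch of the same dispatch pattern:

def lookup(obj, registry, fallback=None):
    for _type in registry:
        if isinstance(obj, _type):
            value = registry[_type](obj)
            break
    else:  # reached only if the loop never hit `break`
        value = fallback(obj) if fallback is not None else None
    return value

registry = {int: lambda x: x * 2}
assert lookup(3, registry) == 6                    # matched: handler result
assert lookup("abc", registry, fallback=len) == 3  # fallback path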
 def print_flops_stats(flops, bar_length_max=20):
-    flops_list = [i["flops_num"] for i in flops]
-    max_flops_num = max(flops_list + [0])
-    # calc total flops and set flops_cum
+    max_flops_num = max([i["flops_num"] for i in flops] + [0])

     total_flops_num = 0
     for d in flops:
         total_flops_num += int(d["flops_num"])
         d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs")

     for d in flops:
-        f = d["flops_num"]
-        d["flops"] = sizeof_fmt(f, suffix="OPs")
-        r = d["ratio"] = f / total_flops_num
-        d["percentage"] = "{:.2f}%".format(r * 100)
-        bar_length = int(f / max_flops_num * bar_length_max)
+        ratio = d["ratio"] = d["flops_num"] / total_flops_num
+        d["percentage"] = "{:.2f}%".format(ratio * 100)
+        bar_length = int(d["flops_num"] / max_flops_num * bar_length_max)
         d["bar"] = "#" * bar_length
+        d["flops"] = sizeof_fmt(d["flops_num"], suffix="OPs")

     header = [
         "name",
         "class_name",
         "input_shapes",
         "output_shapes",
+        "receptive_field",
+        "stride",
         "flops",
         "flops_cum",
         "percentage",
@@ -154,8 +231,8 @@ def get_param_stats(param: np.ndarray):
     param_size = param_dim * nbits // 8
     return {
         "shape": shape,
-        "mean": param.mean(),
-        "std": param.std(),
+        "mean": "{:.3g}".format(param.mean()),
+        "std": "{:.3g}".format(param.std()),
         "param_dim": param_dim,
         "nbits": nbits,
         "size": param_size,
@@ -163,21 +240,20 @@ def get_param_stats(param: np.ndarray):
 def print_params_stats(params, bar_length_max=20):
+    max_size = max([d["size"] for d in params] + [0])
     total_param_dims, total_param_size = 0, 0
     for d in params:
         total_param_dims += int(d["param_dim"])
         total_param_size += int(d["size"])
-        ratio = d["size"] / total_param_size
-        d["size"] = sizeof_fmt(d["size"])
         d["size_cum"] = sizeof_fmt(total_param_size)
-        d["ratio"] = ratio
-        d["percentage"] = "{:.2f}%".format(ratio * 100)

-    # construct bar
-    max_ratio = max([d["ratio"] for d in params])
     for d in params:
-        bar_length = int(d["ratio"] / max_ratio * bar_length_max)
+        ratio = d["size"] / total_param_size
+        d["ratio"] = ratio
+        d["percentage"] = "{:.2f}%".format(ratio * 100)
+        bar_length = int(d["size"] / max_size * bar_length_max)
         d["size_bar"] = "#" * bar_length
+        d["size"] = sizeof_fmt(d["size"])

     param_size = sizeof_fmt(total_param_size)
     params.append(dict(name="total", param_dim=total_param_dims, size=param_size,))
@@ -225,26 +301,14 @@ def module_stats(
     :param log_flops: whether to print and record op flops.
     """

-    def module_stats_hook(module, input, output, name=""):
+    def module_stats_hook(module, inputs, outputs, name=""):
         class_name = str(module.__class__).split(".")[-1].split("'")[0]

-        flops_fun = CALC_FLOPS.get(type(module))
-        if callable(flops_fun):
-            flops_num = flops_fun(module, input, output)
-
-            if not isinstance(output, (list, tuple)):
-                output = [output]
-
-            flops.append(
-                dict(
-                    name=name,
-                    class_name=class_name,
-                    input_shapes=[i.shape for i in input],
-                    output_shapes=[o.shape for o in output],
-                    flops_num=flops_num,
-                    flops_cum=0,
-                )
-            )
+        flops_stats = get_flops_stats(module, inputs, outputs)
+        if flops_stats is not None:
+            flops_stats["name"] = name
+            flops_stats["class_name"] = class_name
+            flops.append(flops_stats)

         if hasattr(module, "weight") and module.weight is not None:
             w = module.weight
@@ -278,19 +342,22 @@ def module_stats(
     for h in hooks:
         h.remove()

-    total_flops, total_params = 0, 0
+    extra_info = {
+        "#params": len(params),
+    }
+    total_flops, total_param_dims, total_param_size = 0, 0, 0
     if log_params:
         total_param_dims, total_param_size = print_params_stats(params, bar_length_max)
+        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
+        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
     if log_flops:
         total_flops = print_flops_stats(flops, bar_length_max)
+        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
+    if log_params and log_flops:
+        extra_info["flops/param_size"] = "{:3.3f}".format(
+            total_flops / total_param_size
+        )

-    extra_info = {
-        "#params": len(params),
-        "total_param_dims": sizeof_fmt(total_param_dims),
-        "total_param_size": sizeof_fmt(total_param_size),
-        "total_flops": sizeof_fmt(total_flops, suffix="OPs"),
-        "flops/param_size": "{:3.3f}".format(total_flops / total_param_size),
-    }
-
     print_summary(**extra_info)

-    return total_params, total_flops
+    return total_param_size, total_flops
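Callers should note the changed return value: module_stats now returns (total_param_size, total_flops) rather than (total_params, total_flops). A hedged usage sketch; the full function signature is not visible in this diff, so the input_size argument below is an assumption:

# hedged sketch only; exact signature not shown in this hunk
# total_param_size, total_flops = module_stats(
#     model, input_size=(1, 3, 224, 224), log_params=True, log_flops=True
# )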
@@ -18,6 +18,11 @@ from ..core.ops import builtin
 from ..core.tensor.megbrain_graph import InputNode
 from ..tensor import Tensor
 from .comp_graph_tools import replace_vars
+from .module_stats import (
+    preprocess_receptive_field,
+    register_flops,
+    register_receptive_field,
+)


 class NetworkNode:
@@ -225,8 +230,21 @@ class Elemwise(OpNode):
     type = "Elemwise"
     opdef = builtin.Elemwise

-    def calc_flops(self):
-        return np.prod(self.outputs[0].shape)
+
+class ElemwiseMultiType(OpNode):
+    type = "ElemwiseMultiType"
+    opdef = builtin.ElemwiseMultiType
+
+    @classmethod
+    def load(cls, opr):
+        obj = super(ElemwiseMultiType, cls).load(opr)
+        obj.params["dtype"] = opr.outputs[0].dtype
+        return obj
+
+
+@register_flops(Elemwise, ElemwiseMultiType)
+def flops_elemwise(opnode: Elemwise, inputs, outputs):
+    return np.prod(outputs[0].shape)
 class Reduce(OpNode):
@@ -255,20 +273,24 @@ class MatrixMul(OpNode):
     type = "MatrixMul"
     opdef = builtin.MatrixMul

-    def calc_flops(self):
-        assert len(self.inputs[0].shape) == 2 and len(self.outputs[0].shape) == 2
-        mid_shape = self.inputs[0].shape[1]
-        return np.prod(self.outputs[0].shape) * mid_shape
+
+@register_flops(MatrixMul)
+def flops_matmul(opnode: MatrixMul, inputs, outputs):
+    assert len(inputs[0].shape) == 2 and len(outputs[0].shape) == 2
+    mid_shape = inputs[0].shape[1]
+    return np.prod(outputs[0].shape) * mid_shape


 class BatchedMatrixMul(OpNode):
     type = "BatchedMatmul"
     opdef = builtin.BatchedMatrixMul

-    def calc_flops(self):
-        assert len(self.inputs[0].shape) == 3 and len(self.outputs[0].shape) == 3
-        mid_shape = self.inputs[0].shape[2]
-        return np.prod(self.outputs[0].shape) * mid_shape
+
+@register_flops(BatchedMatrixMul)
+def flops_batchmatmul(opnode: BatchedMatrixMul, inputs, outputs):
+    assert len(inputs[0].shape) == 3 and len(outputs[0].shape) == 3
+    mid_shape = inputs[0].shape[2]
+    return np.prod(outputs[0].shape) * mid_shape


 class Dot(OpNode):
@@ -285,18 +307,6 @@ class ConvolutionForward(OpNode):
     type = "Convolution"
     opdef = builtin.Convolution

-    def calc_flops(self):
-        param_W_shape = self.inputs[1].shape
-        kh = param_W_shape[-2]
-        kw = param_W_shape[-1]
-        if len(param_W_shape) == 5:
-            num_input = param_W_shape[2]
-        else:
-            num_input = param_W_shape[1]
-        NCHW = np.prod(self.outputs[0].shape)
-        # N x Cout x H x W x (Cin x Kw x Kh)
-        return NCHW * (num_input * kw * kh)


 class ConvolutionBackwardData(OpNode):
     type = "ConvTranspose"
@@ -343,17 +353,41 @@ class ConvBiasForward(OpNode):
         obj.params["dtype"] = opr.outputs[0].dtype
         return obj

-    def calc_flops(self):
-        param_W_shape = self.inputs[1].shape
-        kh = param_W_shape[-2]
-        kw = param_W_shape[-1]
-        if len(param_W_shape) == 5:
-            num_input = param_W_shape[2]
-        else:
-            num_input = param_W_shape[1]
-        NCHW = np.prod(self.outputs[0].shape)
-        # N x Cout x H x W x (Cin x Kw x Kh + bias)
-        return NCHW * (num_input * kw * kh + 1)
+
+@register_flops(
+    ConvolutionForward, ConvBiasForward,
+)
+def flops_conv(opnode: ConvolutionForward, inputs, outputs):
+    param_W_shape = inputs[1].shape
+    kh = param_W_shape[-2]
+    kw = param_W_shape[-1]
+    if len(param_W_shape) == 5:
+        num_input = param_W_shape[2]
+    else:
+        num_input = param_W_shape[1]
+    NCHW = np.prod(outputs[0].shape)
+    bias = 1 if isinstance(opnode, ConvBiasForward) else 0
+    # N x Cout x H x W x (Cin x Kw x Kh + bias)
+    return NCHW * (num_input * kw * kh + bias)
+
+
+@register_receptive_field(ConvolutionForward, ConvBiasForward)
+def receptive_field(opnode: ConvolutionForward, inputs, outputs):
+    pre_rf, pre_stride = preprocess_receptive_field(opnode, inputs, outputs)
+    param_W_shape = inputs[1].shape
+    kh = param_W_shape[-2]
+    kw = param_W_shape[-1]
+    rf = (
+        kh * pre_stride[0] + pre_rf[0] - pre_stride[0],
+        kw * pre_stride[1] + pre_rf[1] - pre_stride[1],
+    )
+    stride = (
+        opnode.params["stride_h"] * pre_stride[0],
+        opnode.params["stride_w"] * pre_stride[1],
+    )
+    opnode._rf = rf
+    opnode._stride = stride
+    return rf, stride
 class BatchConvBiasForward(OpNode):
@@ -652,20 +686,6 @@ class AssertEqual(OpNode):
     opdef = builtin.AssertEqual


-class ElemwiseMultiType(OpNode):
-    type = "ElemwiseMultiType"
-    opdef = builtin.ElemwiseMultiType
-
-    @classmethod
-    def load(cls, opr):
-        obj = super(ElemwiseMultiType, cls).load(opr)
-        obj.params["dtype"] = opr.outputs[0].dtype
-        return obj
-
-    def calc_flops(self):
-        return np.prod(self.outputs[0].shape)


 class CvtColorForward(OpNode):
     type = "CvtColor"
     opdef = builtin.CvtColor
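Worked check of the receptive_field handler above (illustrative numbers): for a 3x3 kernel (kh = kw = 3) with stride_h = stride_w = 2 applied on top of pre_rf = (3, 3) and pre_stride = (1, 1), the handler yields rf = (3*1 + 3 - 1, 3*1 + 3 - 1) = (5, 5) and stride = (2*1, 2*1) = (2, 2), matching the textbook receptive-field growth for stacked convolutions.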