From e4d42cd887d54b7db36e50ada512f3c0ed08b01d Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 15 Dec 2021 19:50:09 +0800 Subject: [PATCH] feat(imperative): add more tools for megengine GitOrigin-RevId: 175d0f7f57a668dee196a34a1e153d828575cd28 --- imperative/python/megengine/tools/README.md | 166 +++++- .../python/megengine/tools/benchmark_op.py | 476 ++++++++++++++++ .../megengine/tools/dump_with_testcase_mge.py | 528 ++++++++++++++++++ 3 files changed, 1166 insertions(+), 4 deletions(-) create mode 100644 imperative/python/megengine/tools/benchmark_op.py create mode 100755 imperative/python/megengine/tools/dump_with_testcase_mge.py diff --git a/imperative/python/megengine/tools/README.md b/imperative/python/megengine/tools/README.md index a2ce3de95..8b202d1b9 100644 --- a/imperative/python/megengine/tools/README.md +++ b/imperative/python/megengine/tools/README.md @@ -1,8 +1,166 @@ # MegEngine Tools -This directory contains executable python files. -Use these files in the following way (replace `xxx` to specific file name, like `network_visualize`): +MegEngine 相关的工具汇总。使用方法如下(可将 `xxx` 替换成任一脚本文件,如 `network_visualize`): -``` +```bash python -m megengine.tools.xxx -``` \ No newline at end of file +``` + +工具列表: + +### accuracy_shake_var_tree + +将精度抖动分析结果构造成树结构,方便锁定引起抖动的根节点,以及查找依赖关系。 + +输入: compare_binary_iodump 的输出存入到的一个文件 + +输出: 第一个出现结果不一致的输出结点 + +执行命令: accuracy_shake_var_tree 中定义了一些函数组件,可按需集成到实际代码中。下面有一个测试代码: + +```python +import megengine.tools.accuracy_shake_var_tree as st + +r = st.parse('diff.txt') +for key, value in r.items(): + n = st.varNode.get_varNode(key) + n.show_src_info() + print("reference nodes:") + for i in n.get_reference_list(): + print(i.id) +``` + +### benchmark_op + +逐个运行 functional op(并不是所有的 functional op),对比 MegEngine 与 PyTorch 的性能,通过量化结果来指导如何进行下一步的优化。 + +输入: 无 + +输出: 打印一个列表,对比在小输入和大输入的情况下 MegEngine 和 Pytorch 执行一些 functional op 的速度对比 + +执行命令: `python3 -m megengine.tools.benchmark_op` + +### compare_binary_iodump + +分析同一模型在不同平台下给定相同输入之后的输出是否完全一致。 + +输入: 两个目录(假设分别为 expect/ 和 actual/),分别存有不同平台下运行的 tensor 结果 + +输出: 打印所有的输出 tensor 信息,如果某个 tensor 在两个平台上的值不一致,那么会打印出第一个不一致的值 + +执行命令: `python3 -m megengine.tools.compare_binary_iodump expect/ actual/` + +### cpu_evaluation_tools + +分析多个模型在目标芯片上的运行性能 + +输入:MegEngine 模型文件 + +输出:根据不同模型的加权,输出芯片性能分数 + +执行命令:python3 ./cpu_evaluation_tools.py --load_and_run_file ./load_and_run --models_dir ./cpu_models/ + +### draw_graph + +用来查看静态图的 op 序列,有助于理解 MegEngine 的静态图在动态图的基础上做了哪些优化。 + +输入: `megengine.core.tensor.megbrain_graph.Graph._to_json` 得出的静态图描述文件,为 json 格式 + +输出: 一个 dot 文件,可通过 dot 命令绘制出图片 + +执行命令: + +```bash +python3 -m megengine.tools.draw_graph -i dump.json -o dump.dot +dot -Tpng dump.dot -o dump.png +``` + +### dump_with_testcase_mge + +将待测数据提前注入模型文件,并在本地运行得到期望结果,可与实际运行的结果进行比对以检查是否出错。 + +输入: 一个 MegEngine 模型文件,可选一些 npy 文件作为模型输入(也可以随机生成输入,如下面的命令示例) + +输出: 一个带输入的 MegEngine 模型文件 + +执行命令: `python3 -m megengine.tools.dump_with_testcase_mge model.mge -d "#rand(0,255,14,2)"` + +### graph_info_analyze + +将图和内存信息的 json 文件的文件夹 logs 转换为 TensorBoard 的输入文件夹 logs_p。以便 TensorBoard 对图结构以及内存信息进行可视化。 + +输入: 图和内存信息的 json 文件的文件夹 + +输出: TensorBoard 的输入文件夹 + +执行命令: `python3 -m megengine.tools.graph_info_analyze -i logs -o logs_p` + +### load_network_and_run + +python 版本的 load_and_run。 + +输入: MegEngine 的模型文件,可选一些 npy 文件作为模型输入 + +输出: 模型执行并打印一些测速信息 + +执行命令: `python3 -m megengine.tools.load_network_and_run model.mge --iter 10` + +### network_visualize + +1. 分析给定的 MegEngine 模型中参数量信息,包括 shape、dtype、mean、std 以及 size 占比等。 +2. 分析给定的 MegEngine 模型中算子 FLOPs 计算量以及占比,还有算子的 inputs/outputs shape、感受野、stride 等。 + +输入: MegEngine 的模型文件 + +输出: 模型中的参数量信息或计算量信息 + +执行命令: + +```bash +# 分析参数量 +python3 -m megengine.tools.network_visualize model.mge --cal_params --logging_to_stdout + +# 分析计算量 +python3 -m megengine.tools.network_visualize model.mge --cal_flops --logging_to_stdout +``` + +### profile_analyze + +对于 load_and_run --profile 运行模型生成的 profile.json 文件或者 trace 模式下开启 profiling 功能并通过 trace.get_profile() 得到的 json 文件进行分析,得到静态图中算子的时间和显存占比等信息,以表格形式呈现。 + +输入: load_and_run 生成的 profile 文件 + +输出: 一个按照参数在输入文件中筛选得出的数据表格 + +执行命令: + +```bash +# 生成供分析的 json 文件 +python3 -m megengine.tools.load_network_and_run model.mge --warm-up --iter 10 --profile profile.json + +#分析耗时前 3 的单个算子 +python3 -m megengine.tools.profile_analyze profile.json -t 3 + +#筛选用时超过 10us 的 conv 按 flops 排序 +python3 -m megengine.tools.profile_analyze profile.json -t 3 --order-by +flops --min-time 1e-5 --type ConvolutionForward +``` + +### profiler + +对给定的训练程序,记录训练过程并以通用格式存储,可在浏览器上可视化。 + +输入: 需要一个 MegEngine 的训练程序(称之为 train.py,其中包含一个典型的 MegEngine 训练过程) + +输出: 一些记录 profile 过程的 json 文件,默认在 profile 子目录下,可用 https://ui.perfetto.dev/ 进行加载并且可视化 + +执行命令: `python3 -m megengine.tools.profiler train.py` + +### svg_viewer + +查看 MegEngine 生成的显存占用图,可以帮助用户了解显存使用情况. + +输入: 显存占用的 svg 图片 + +输出: 网页展示的可视化 + +执行命令: `python3 -m megengine.tools.svg_viewer` diff --git a/imperative/python/megengine/tools/benchmark_op.py b/imperative/python/megengine/tools/benchmark_op.py new file mode 100644 index 000000000..19bf0ff9e --- /dev/null +++ b/imperative/python/megengine/tools/benchmark_op.py @@ -0,0 +1,476 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import time + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as TF +from tabulate import tabulate + +import megengine as mge +import megengine.functional as MF +import megengine.module as MM + +module_cache = { + "conv2d": (MM.Conv2d(32, 32, 3, 1, 0), nn.Conv2d(32, 32, 3, 1, 0).cuda()), + "dw_conv2d": ( + MM.Conv2d(32, 32, 3, 1, 0, groups=32), + nn.Conv2d(32, 32, 3, 1, 0, groups=32).cuda(), + ), + "conv3d": (MM.Conv3d(32, 32, 3, 1, 0), nn.Conv3d(32, 32, 3, 1, 0).cuda()), + "ConvTranspose2d": ( + MM.ConvTranspose2d(32, 32, 3, 1, 0), + nn.ConvTranspose2d(32, 32, 3, 1, 0).cuda(), + ), + "BatchNorm2d": (MM.BatchNorm2d(64), nn.BatchNorm2d(64).cuda()), + "Linear": (MM.Linear(1000, 1000), nn.Linear(1000, 1000).cuda()), +} + +test_cases = [ + # (mge op, torch op, small inps, large inps, unpack_inps, rep) + ( + "adaptive_avg_pool2d", + lambda x: MF.adaptive_avg_pool2d(x, (7, 7)), + lambda x: TF.adaptive_avg_pool2d(x, (7, 7)), + [(2, 32, 16, 16)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ( + "adaptive_max_pool2d", + lambda x: MF.adaptive_max_pool2d(x, (7, 7)), + lambda x: TF.adaptive_max_pool2d(x, (7, 7)), + [(2, 32, 16, 16)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ("argsort", MF.argsort, torch.argsort, [(1000,)], [(1000, 1000),], True, 1000), + ( + "avg_pool2d", + lambda x: MF.avg_pool2d(x, 2), + lambda x: TF.avg_pool2d(x, 2), + [(2, 32, 16, 16)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ( + "broadcast", + lambda x: MF.broadcast_to(x, (5,) + x.shape), + lambda x: torch.broadcast_to(x, (5,) + x.shape), + [(100, 100)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ( + "batchedmatmul", + MF.matmul, + torch.matmul, + [(8, 64, 32), (8, 32, 64)], + [(8, 2048, 512), (8, 512, 2048)], + True, + 1000, + ), + ( + "batchnrom2d", + lambda x: module_cache["BatchNorm2d"][0](x), + lambda x: module_cache["BatchNorm2d"][1](x), + [(2, 64, 16, 16)], + [(64, 64, 128, 128)], + True, + 1000, + ), + ( + "concat", + MF.concat, + torch.cat, + [(20, 100), (50, 100), (30, 100)], + [(64, 512, 16, 16), (64, 512, 16, 16), (64, 512, 16, 16)], + False, + 1000, + ), + ( + "conv2d", + lambda x: module_cache["conv2d"][0](x), + lambda x: module_cache["conv2d"][1](x), + [(2, 32, 16, 16)], + [(32, 32, 128, 128)], + True, + 1000, + ), + ( + "conv3d", + lambda x: module_cache["conv3d"][0](x), + lambda x: module_cache["conv3d"][1](x), + [(2, 32, 8, 8, 8)], + [(32, 32, 16, 16, 16)], + True, + 1000, + ), + ( + "convTranspose2d", + lambda x: module_cache["ConvTranspose2d"][0](x), + lambda x: module_cache["ConvTranspose2d"][1](x), + [(2, 32, 16, 16)], + [(32, 32, 128, 128)], + True, + 1000, + ), + ( + "dropout", + lambda x: MF.dropout(x, 0.5), + TF.dropout, + [(100, 100)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ( + "dw_conv2d", + lambda x: module_cache["dw_conv2d"][0](x), + lambda x: module_cache["dw_conv2d"][1](x), + [(2, 32, 16, 16)], + [(32, 32, 128, 128)], + True, + 1000, + ), + ( + "elemwise.unary", + MF.log, + torch.log, + [(100, 100)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ( + "elemwise.binary", + MF.add, + torch.add, + [(100, 100), (100, 100)], + [(64, 512, 16, 16), (64, 512, 16, 16)], + True, + 1000, + ), + ( + "expand_dims", + lambda x: MF.expand_dims(x, 0), + lambda x: torch.unsqueeze(x, 0), + [(100, 100)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ("gelu", MF.gelu, TF.gelu, [(100, 100)], [(64, 512, 16, 16)], True, 1000), + ("hswish", MF.hswish, TF.hardswish, [(100, 100)], [(64, 512, 16, 16)], True, 1000), + ( + "hsigmoid", + MF.hsigmoid, + TF.hardsigmoid, + [(100, 100)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ("isinf", MF.isinf, torch.isinf, [(100, 100)], [(64, 512, 16, 16)], True, 1000), + ( + "indeixngMultiAxisVec", + lambda x: x[[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]], + lambda x: x[[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]], + [(10, 10, 10, 10)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ( + "logsigmoid", + MF.logsigmoid, + TF.logsigmoid, + [(100, 100)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ( + "leaky_relu", + lambda x: MF.leaky_relu(x, 0.5), + lambda x: TF.leaky_relu(x, 0.5), + [(100, 100)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ( + "linear", + lambda x: module_cache["Linear"][0](x), + lambda x: module_cache["Linear"][1](x), + [(10, 1000)], + [(64, 128, 1000)], + True, + 1000, + ), + ("matinv", MF.matinv, torch.inverse, [(10, 10)], [(30, 30)], True, 1000), + ( + "matmul", + MF.matmul, + torch.matmul, + [(64, 32), (32, 64)], + [(2048, 1024), (1024, 2048)], + True, + 1000, + ), + ( + "max_pool2d", + lambda x: MF.max_pool2d(x, 2), + lambda x: TF.max_pool2d(x, 2), + [(2, 32, 16, 16)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ( + "normal", + lambda x: mge.random.normal(0, 1, x.shape), + lambda x: torch.randn(x.shape, device="cuda"), + [(100, 100)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ( + "prelu", + MF.prelu, + TF.prelu, + [(100, 100), (1,)], + [(64, 512, 16, 16), (1,)], + True, + 1000, + ), + ( + "reduce.max", + lambda x: MF.max(x, 0), + lambda x: torch.max(x, 0), + [(100, 100)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ( + "reduce.mean", + lambda x: MF.mean(x, 0), + lambda x: torch.mean(x, 0), + [(100, 100)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ( + "reduce.mean", + lambda x: MF.mean(x, 0), + lambda x: torch.mean(x, 0), + [(100, 100)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ("relu", MF.relu, TF.relu, [(100, 100)], [(64, 512, 16, 16)], True, 1000), + ("relu6", MF.relu6, TF.relu6, [(100, 100)], [(64, 512, 16, 16)], True, 1000), + ( + "repeat", + lambda x: MF.repeat(x, 5), + lambda x: torch.repeat_interleave(x, 5), + [(100, 100)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ("silu", MF.silu, TF.silu, [(100, 100)], [(64, 512, 16, 16)], True, 1000), + ( + "split", + lambda x: MF.split(x, 5), + lambda x: torch.split(x, 5), + [(100, 100)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ("sigmoid", MF.sigmoid, TF.sigmoid, [(100, 100)], [(64, 512, 16, 16)], True, 1000), + ( + "softmax", + lambda x: MF.softmax(x, axis=1), + lambda x: TF.softmax(x, dim=1), + [(100, 100)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ( + "softplus", + MF.softplus, + TF.softplus, + [(100, 100)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ( + "squeeze", + lambda x: MF.squeeze(x, 0), + lambda x: torch.squeeze(x, 0), + [(1, 100, 100)], + [(1, 64, 512, 16, 16)], + True, + 1000, + ), + ( + "stack", + MF.stack, + torch.stack, + [(100, 100), (100, 100)], + [(64, 512, 16, 16), (64, 512, 16, 16)], + False, + 10000, + ), + ( + "subtensor", + lambda x: x[0:20, 10:60], + lambda x: x[0:20, 10:60], + [(100, 100)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ( + "topk", + lambda x: MF.topk(x, 10), + lambda x: torch.topk(x, 10), + [(100, 100)], + [(1000, 1000)], + True, + 1000, + ), + ( + "tile", + lambda x: MF.tile(x, (2,) * len(x.shape)), + lambda x: torch.tile(x, (2,) * len(x.shape)), + [(100, 100)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ( + "transpose", + lambda x: MF.transpose(x, list(range(len(x.shape)))[::-1]), + lambda x: torch.permute(x, list(range(len(x.shape)))[::-1]), + [(100, 100)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ( + "where", + lambda x: MF.where(x > 0.5, x, x), + lambda x: torch.where(x > 0.5, x, x), + [(100, 100)], + [(64, 512, 16, 16)], + True, + 1000, + ), + ( + "uniform", + lambda x: mge.random.uniform(0, 1, x.shape), + lambda x: torch.rand(x.shape, device="cuda"), + [(100, 100)], + [(64, 512, 16, 16)], + True, + 1000, + ), +] + + +def perf_func(func, inps, reps, unpack_inps, is_mge): + if is_mge: + mge._full_sync() + tik = time.time() + for _ in range(reps): + if unpack_inps: + out = func(*inps) + else: + out = func(inps) + mge._full_sync() + else: + torch.cuda.synchronize() + with torch.no_grad(): + tik = time.time() + for _ in range(reps): + if unpack_inps: + out = func(*inps) + else: + out = func(inps) + torch.cuda.synchronize() + return time.time() - tik + + +def get_avg_time(func, inps, reps, unpack_inps, is_mge): + # warm up + for _ in range(2): + t = perf_func(func, inps, reps, unpack_inps, is_mge) + + times = [] + for _ in range(5): + t = perf_func(func, inps, reps, unpack_inps, is_mge) + times.append(t) + return np.mean(times) + + +def get_perf_results(mge_func, torch_func, shapes, unpack_inps, reps): + inps = [np.random.randn(*shape) for shape in shapes] + + inps_mge = [mge.tensor(inp, dtype="float32") for inp in inps] + avg_time_mge = get_avg_time(mge_func, inps_mge, reps, unpack_inps, True) + + inps_torch = [torch.Tensor(inp).type(torch.float).cuda() for inp in inps] + avg_time_torch = get_avg_time(torch_func, inps_torch, reps, unpack_inps, False) + + return avg_time_mge, avg_time_torch + + +if __name__ == "__main__": + header = [ + "opr_name", + "time(mge/pytorch; small input)", + "time(mge/pytorch; large input)", + ] + table = [] + for case in test_cases: + assert len(case) == 7 + name, mge_func, torch_func, small_shapes, large_shapes, unpack_inps, reps = case + data = [] + data.append(name) + print("========== op: {}".format(name)) + + avg_time_mge, avg_time_torch = get_perf_results( + mge_func, torch_func, small_shapes, unpack_inps, reps + ) + print("mge time: {}".format(avg_time_mge)) + print("torch time: {}".format(avg_time_torch)) + data.append("{:.2f}".format(avg_time_mge / avg_time_torch)) + + avg_time_mge, avg_time_torch = get_perf_results( + mge_func, torch_func, large_shapes, unpack_inps, reps + ) + print("mge time: {}".format(avg_time_mge)) + print("torch time: {}".format(avg_time_torch)) + data.append("{:.2f}".format(avg_time_mge / avg_time_torch)) + table.append(data) + print(tabulate(table, header, tablefmt="github")) diff --git a/imperative/python/megengine/tools/dump_with_testcase_mge.py b/imperative/python/megengine/tools/dump_with_testcase_mge.py new file mode 100755 index 000000000..1f53b2c26 --- /dev/null +++ b/imperative/python/megengine/tools/dump_with_testcase_mge.py @@ -0,0 +1,528 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import argparse +import os +import re +import struct + +import cv2 +import numpy as np + +import megengine as mge +import megengine.core._imperative_rt as rt +import megengine.core.tensor.megbrain_graph as G +from megengine import tensor +from megengine.core.ops import builtin +from megengine.utils import comp_graph_tools as cgtools + +logger = mge.get_logger(__name__) + + +def auto_reformat_image(args, path, data, dst_shape): + """reformat image to target shape + + :param data: image data as numpy array + :param dst_shape: target shape + """ + dim3_format = False # required input format does not contain batch + hwc_format = False # required input format is NHWC + + if not dst_shape: # input tensor shape is not predefined + if len(data.shape) == 2: + chl = 1 + h = data.shape[0] + w = data.shape[1] + else: + assert len(data.shape) == 3, "Input image must be of dimension 2 or 3" + h, w, chl = data.shape + dst_shape = (1, chl, h, w) + + if len(dst_shape) == 3: + dst_shape = (1,) + dst_shape + dim3_format = True + + assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape) + chl = dst_shape[1] + if chl in [1, 3]: + n, c, h, w = dst_shape + dst_shape = (n, h, w, c) + else: + chl = dst_shape[3] + assert chl in [1, 3], "can not infer input format from shape: {}".format( + dst_shape + ) + hwc_format = True + + # dst_shape has now been normalized to NHWC format + + if args.resize_input: + h, w = dst_shape[1:3] + data = cv2.resize(data, (w, h)) + logger.info("input {} resized to {}".format(path, data.shape)) + + if chl == 1: + data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY) + data = data[:, :, np.newaxis] + + assert data.ndim == 3 + data = data[np.newaxis] + # data normalized to NHWC format + + if not hwc_format: + data = np.transpose(data, (0, 3, 1, 2)) + + if dim3_format: + data = np.squeeze(data, 0) + + return data + + +def read_input_data(args, dst_shape, dtype, path, repeat): + def check_shape_equal(dst_shape, data_shape): + if len(dst_shape): + assert len(data_shape) == len( + dst_shape + ), "input/data shapes mismatch: {} vs {}".format(dst_shape, data_shape) + + if data_shape[1:] != dst_shape[1:]: + logger.warning( + "dst_shape is {}; data_shape is {}".format(dst_shape, data_shape) + ) + + if path.startswith("#"): + assert not args.resize_input + assert not args.input_transform + spec = path + m = re.match(r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec) + assert m, "bad spec {}".format(spec) + + rng_min = float(m.group(1)) + rng_max = float(m.group(2)) + if m.group(3): + shape_str = m.group(3) + try: + shape = shape_str[1:].split(",") + if shape[-1].strip() == "...": + shape = shape[:-1] + shape.extend(list(dst_shape[len(shape) :])) + data_shape = tuple(map(int, shape)) + except ValueError as e: + raise ValueError("bad spec {}: {}".format(spec, e.args)) + else: + data_shape = dst_shape + + check_shape_equal(dst_shape, data_shape) + return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype) + + # try to load image + data = cv2.imread(path, cv2.IMREAD_COLOR) + if data is None: + assert not args.resize_input + data = np.load(path) + assert isinstance(data, np.ndarray) + else: + # load image succeeds, so we expect input format is image format + data = auto_reformat_image(args, path, data, dst_shape) + + data = np.repeat(data, repeat, axis=0) + if repeat > 1: + logger.info( + "repeat input for {} times, data shape is {}".format(repeat, data.shape) + ) + + check_shape_equal(dst_shape, data.shape) + + if args.input_transform: + data = eval(args.input_transform, {"data": data, "np": np}) + + return data + + +def gen_one_testcase(args, inputs, spec): + paths = spec.split(";") + if len(paths) != len(inputs): + if len(paths) == 1 and paths[0].startswith("#"): + paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()] + assert len(paths) == len(inputs), "required inputs: {}; data paths: {}".format( + inputs.keys(), paths + ) + if len(paths) == 1 and ":" not in paths[0]: + paths[0] = next(iter(inputs.keys())) + ":" + paths[0] + + ret = {} + for path in paths: + var, path = path.split(":") + if args.repeat: + repeat = args.repeat + else: + repeat = 1 + ret[var] = read_input_data( + args, inputs[var].shape, inputs[var].dtype, path, repeat + ) + return ret + + +def make_feeds(args): + ret = G.load_graph(args.input) + cg_rt, outputs = ret.graph, ret.output_vars_list + inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy") + + inputs = {i.name: i for i in inputs} + if not args.no_assert: + + replace_varmap = {} + inp_map = {} + # replace var use InputNode + for name, var in inputs.items(): + inp = G.InputNode( + device="xpux", dtype=var.dtype, shape=var.shape, graph=cg_rt + ) + replace_varmap[var] = inp.outputs[0] + inp_map[name] = inp + + new = cgtools.replace_vars(outputs, replace_varmap) + if isinstance(new, rt.VarNode): + new = list(new) + + output_nodes = [G.OutputNode(var) for var in new] + func = cg_rt.compile([node.outputs[0] for node in output_nodes]) + + def make_dev_tensor(value, dtype=None, device=None): + return tensor(value, dtype=dtype, device=device)._dev_tensor() + + def calculate(*args, **kwargs): + output_val = [] + # set inputs value + for name, var in inputs.items(): + val = kwargs.pop(name, None) + assert val is not None, "miss input name{}".format(name) + dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux") + inp_map[name].set_value(dev_tensor) + + func.execute() + + for res in output_nodes: + output_val.append(res.get_value().numpy()) + return output_val + + def expect_name(var): + return "{}:expect".format(var.name) + + testcases = [] + + np.set_printoptions(precision=2, threshold=4, suppress=True) + + data_list = [] + for item in args.data: + if item.startswith("@"): + with open(item[1:], "r") as f: + data_list.extend([line.rstrip() for line in f if line.rstrip() != ""]) + else: + data_list.append(item) + + for inp_spec in data_list: + cur_testcase = gen_one_testcase(args, inputs, inp_spec) + assert len(cur_testcase) == len( + inputs + ), "required inputs: {}; given data: {}".format( + inputs.keys(), cur_testcase.keys() + ) + + if not args.no_assert: + outputs_get = calculate(**cur_testcase) + for var, val in zip(outputs, outputs_get): + cur_testcase[expect_name(var)] = val + logger.info( + "generate test groundtruth: var={} shape={} range=({}, {})" + " mean={} var={}".format( + var, val.shape, val.min(), val.max(), np.mean(val), np.var(val) + ) + ) + testcases.append(cur_testcase) + logger.info( + "add testcase: \n {}".format( + "\n ".join( + "{}: shape={} dtype={} range=({:.2f},{:.2f}) " + "mean={:.2f} sd={:.2f}".format( + k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v) + ) + for k, v in sorted(cur_testcase.items()) + ) + ) + ) + + if not args.no_assert: + + def expect_shp(var): + ret = var.shape + if ret: + return ret + return testcases[0][expect_name(var)].shape + + def assert_equal(expect, real, **kwargs): + op = builtin.AssertEqual(**kwargs) + (res,) = G.apply_normal_varnode(op, expect, real) + return res + + verbose = not args.silent + + outputs_new = [] + for i in outputs: + device = rt.CompNode("xpux") + dtype = i.dtype + name = expect_name(i) + shape = expect_shp(i) + # make expect output as one input of model. + expect_get = rt.make_h2d(cg_rt, device, dtype, shape, name) + # insert assert opr to check expect and real. + outputs_new.append( + assert_equal(expect_get, i, verbose=verbose, maxerr=args.maxerr,) + ) + inputs[expect_name(i)] = expect_get + outputs = outputs_new + + return {"outputs": outputs, "testcases": testcases} + + +def optimize_for_inference(args, outputs): + args_list = [ + "enable_io16xc32", + "enable_ioc16", + "enable_hwcd4", + "enable_nchw4", + "enable_nchw88", + "enable_nchw44", + "enable_nchw44_dot", + "enable_nchw32", + "enable_chwn4", + "enable_fuse_conv_bias_nonlinearity", + "enable_fuse_conv_bias_with_z", + "enable_fuse_preprocess", + ] + kwargs = {} + for k in args_list: + if getattr(args, k): + kwargs[k] = True + + if args.optimize_for_inference: + outputs = G.optimize_for_inference(outputs, **kwargs) + + return outputs + + +def main(): + parser = argparse.ArgumentParser( + description="Pack computing graph, input values and expected output " + "values into one file for checking correctness. README.md gives more " + "details on the usage", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("input", help="MegEngine dumped model file") + parser.add_argument("-o", "--output", help="output file", required=True) + parser.add_argument( + "-d", + "--data", + default=[], + action="append", + required=True, + help="Given input test data when input file is a network, " + "and current network output would be used as groundtruth. " + "The format is var0:file0;var1:file1... to specify data files for " + "input vars. It can also be #rand(min,max,shape...) for generating " + "random input data, for example, #rand(0,255), " + "#rand(0,255,1,3,224,224) or #rand(0, 255, 1, ...) where `...` means " + "the remaining part of the original shape. " + "If the shape is not specified, the shape of " + "corresponding input tensors in the network will be used. " + "If there is only one input var, its name can be omitted. " + "Each data file can either be an image which can be loaded by opencv, " + "or a pickled numpy.ndarray. " + "This option can be given multiple times to add multiple testcases. " + " *NOTE* " + "If you start the data with the letter @, the rest should be a " + "filename, and each line in the file should be a single datum in " + "the format described above. ", + ) + parser.add_argument( + "--repeat", + type=int, + default=1, + help="Specify how many times the input image is repeated. " + "Useful when running benchmark for batch size other than one. " + "Have no effect on randomly generated input data.", + ) + parser.add_argument( + "--silent", + action="store_true", + help="set verbose to False in asserti_equal opr", + ) + parser.add_argument( + "--optimize-for-inference", + action="store_true", + help="enable optimization for inference", + ) + parser.add_argument( + "--no-assert", + action="store_true", + help="do not insert assert_equal opr to check result; " + "this option is useful for benchmarking", + ) + parser.add_argument( + "--maxerr", + type=float, + default=1e-4, + help="max error for assert_equal check during runtime", + ) + parser.add_argument( + "--resize-input", + action="store_true", + help="resize input image to fit input var shape", + ) + parser.add_argument( + "--input-transform", + help="a python expression to transform the input data. " + "Example: data / np.std(data)", + ) + parser.add_argument( + "--discard-var-name", + action="store_true", + help="discard variable and param names in the " "generated output", + ) + parser.add_argument( + "--output-strip-info", action="store_true", help="output code strip information" + ) + parser.add_argument( + "--enable-io16xc32", + action="store_true", + help="transform the mode to float16 io float32 compute", + ) + parser.add_argument( + "--enable-ioc16", + action="store_true", + help="transform the dtype of the model to float16 io " "and compute", + ) + parser.add_argument( + "--enable-fuse-conv-bias-nonlinearity", + action="store_true", + help="fuse convolution bias and nonlinearity opr to a " + "conv_bias opr and compute", + ) + parser.add_argument( + "--enable-hwcd4", + action="store_true", + help="transform the model format from NCHW to NHWCD4 " + "for inference; you may need to disable CUDA and set " + "MGB_USE_MEGDNN_DBG=2", + ) + parser.add_argument( + "--enable-nchw4", + action="store_true", + help="transform the model format from NCHW to NCHW4 " "for inference", + ) + parser.add_argument( + "--enable-nchw88", + action="store_true", + help="transform the model format from NCHW to NCHW88 " "for inference", + ) + parser.add_argument( + "--enable-nchw44", + action="store_true", + help="transform the model format from NCHW to NCHW44 " "for inference", + ) + parser.add_argument( + "--enable-nchw44-dot", + action="store_true", + help="transform the model format from NCHW to NCHW44_DOT " + "for optimizing armv8.2 dot in inference", + ) + parser.add_argument( + "--enable-nchw32", + action="store_true", + help="transform the model format from NCHW4 to NCHW32 " + "for inference on nvidia TensoCore", + ) + parser.add_argument( + "--enable-chwn4", + action="store_true", + help="transform the model format to CHWN4 " + "for inference, mainly used for nvidia tensorcore", + ) + parser.add_argument( + "--enable-fuse-conv-bias-with-z", + action="store_true", + help="fuse conv_bias with z input for inference on " + "nvidia GPU (this optimization pass will result in mismatch " + "of the precision of output of training and inference)", + ) + parser.add_argument( + "--enable-fuse-preprocess", + action="store_true", + help="fuse astype\pad_channel\dimshuffle and etc opr " "from h2d opr", + ) + args = parser.parse_args() + + feeds = make_feeds(args) + + assert isinstance(feeds, dict) and feeds["testcases"], "testcases can not be empty" + + output_mgbvars = feeds["outputs"] + output_mgbvars = optimize_for_inference(args, output_mgbvars) + + inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy") + inputs = sorted((i.name, i.dtype) for i in inputs) + + if args.discard_var_name: + sereg_kwargs = dict(keep_var_name=0, keep_param_name=False) + else: + sereg_kwargs = dict(keep_var_name=2, keep_param_name=True) + + strip_info_file = args.output + ".json" if args.output_strip_info else None + + with open(args.output, "wb") as fout: + fout.write(b"mgbtest0") + fout.write(struct.pack("I", len(feeds["testcases"]))) + dump_content, stat = G.dump_graph( + output_mgbvars, + append_json=True, + strip_info_file=strip_info_file, + **sereg_kwargs, + ) + fout.write(dump_content) + + logger.info( + "graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB".format( + stat.tot_bytes / 1024, (stat.tot_bytes - stat.tensor_value_bytes) / 1024 + ) + ) + + def make_dev_tensor(value, dtype=None, device=None): + return tensor(value, dtype=dtype, device=device)._dev_tensor() + + for testcase in feeds["testcases"]: + assert isinstance(testcase, dict) + cg = G.Graph() + output_mgbvars = [] + for name, dtype in inputs: + output_mgbvars.append( + cg.make_const( + make_dev_tensor(testcase.pop(name), dtype=dtype, device="cpux") + ) + ) + assert not testcase, "extra inputs provided in testcase: {}".format( + testcase.keys() + ) + with open(args.output, "ab") as fout: + dump_content, _ = G.dump_graph( + output_mgbvars, strip_info_file=strip_info_file, append_json=True + ) + fout.write(dump_content) + + +if __name__ == "__main__": + main() -- GitLab