From 597a1e791bc931b4b23adf0ae4d8cc47cf64ec86 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 26 Apr 2022 21:06:35 +0800
Subject: [PATCH] refactor(imperative): add interface to clear algorithm cache

GitOrigin-RevId: 662618954bc5dee254f294f4ce8b2e4efb95d87b
---
 dnn/include/megdnn/algorithm_cache.h          |   2 +-
 imperative/python/megengine/core/_config.py   |   5 +-
 .../megengine/functional/debug_param.py       |  28 +-
 imperative/python/src/tensor.cpp              |   3 +
 .../integration/test_correctness_mnistnet.py  | 289 ------------------
 .../test/integration/test_dp_correctness.py   |   9 +-
 .../python/test/unit/module/test_conv.py      |  21 +-
 .../python/test/unit/quantization/test_op.py  |   3 +-
 .../test/unit/utils/test_network_node.py      |  33 +-
 9 files changed, 50 insertions(+), 343 deletions(-)
 delete mode 100644 imperative/python/test/integration/test_correctness_mnistnet.py

diff --git a/dnn/include/megdnn/algorithm_cache.h b/dnn/include/megdnn/algorithm_cache.h
index 1180dcaa9..1950547d4 100644
--- a/dnn/include/megdnn/algorithm_cache.h
+++ b/dnn/include/megdnn/algorithm_cache.h
@@ -71,7 +71,7 @@ public:
     MGE_WIN_DECLSPEC_FUC Result get(const Key& key);
 
-    void clear();
+    MGE_WIN_DECLSPEC_FUC void clear();
 
 private:
     struct Hash {
diff --git a/imperative/python/megengine/core/_config.py b/imperative/python/megengine/core/_config.py
index b55c5c325..c6a9d0ac8 100644
--- a/imperative/python/megengine/core/_config.py
+++ b/imperative/python/megengine/core/_config.py
@@ -9,7 +9,7 @@
 import os
 from contextlib import contextmanager
 
-from ._imperative_rt.core2 import get_option, set_option
+from ._imperative_rt.core2 import _clear_algorithm_cache, get_option, set_option
 
 __compute_mode = "default"
 __conv_format = "default"
@@ -44,6 +44,9 @@ def benchmark_kernel(mod):
 
 @benchmark_kernel.setter
 def benchmark_kernel(mod, option: bool):
     global _benchmark_kernel
+    # switching to a different strategy invalidates cached choices, so clear the algorithm cache
+    if option != _benchmark_kernel:
+        _clear_algorithm_cache()
     _benchmark_kernel = option
 
diff --git a/imperative/python/megengine/functional/debug_param.py b/imperative/python/megengine/functional/debug_param.py
index d83f2e8d2..4f2f99a46 100644
--- a/imperative/python/megengine/functional/debug_param.py
+++ b/imperative/python/megengine/functional/debug_param.py
@@ -9,6 +9,7 @@
 import os
 
 from ..core import _config
+from ..core._imperative_rt.core2 import _clear_algorithm_cache
 from ..core.ops import builtin
 from ..logger import get_logger
 from ..utils.deprecation import deprecated
@@ -52,7 +53,6 @@ def set_execution_strategy(option):
     * "HEURISTIC": uses heuristic to choose the fastest algorithm.
     * "PROFILE": runs possible algorithms on a real device to find the best one.
     * "REPRODUCIBLE": uses algorithms that are reproducible.
-    * "OPTIMIZED": uses algorithms that are optimized.
 
     The default strategy is "HEURISTIC", these options can be combined to form
     a combination option, e.g. PROFILE_REPRODUCIBLE is a combination
@@ -70,22 +70,25 @@
     It can also be set through the environment variable ``MEGENGINE_EXECUTION_STRATEGY``.
""" - + _benchmark_kernel = False + _deterministic_kernel = False if isinstance(option, Strategy): - _config._benchmark_kernel = ( + _benchmark_kernel = ( True if option & _valid_string_option["PROFILE"] != Strategy(0) else False ) - _config._deterministic_kernel = ( + _deterministic_kernel = ( True if option & _valid_string_option["REPRODUCIBLE"] != Strategy(0) else False ) + if _benchmark_kernel != _config._benchmark_kernel: + _clear_algorithm_cache() + _config._benchmark_kernel = _benchmark_kernel + _config._deterministic_kernel = _deterministic_kernel return assert isinstance(option, str) - _config._benchmark_kernel = False - _config._deterministic_kernel = False for opt in option.split("_"): if not opt in _valid_string_option: raise ValueError( @@ -93,10 +96,12 @@ def set_execution_strategy(option): _valid_string_option.keys() ) ) - _config._benchmark_kernel |= _valid_string_option[opt] == Strategy.PROFILE - _config._deterministic_kernel |= ( - _valid_string_option[opt] == Strategy.REPRODUCIBLE - ) + _benchmark_kernel |= _valid_string_option[opt] == Strategy.PROFILE + _deterministic_kernel |= _valid_string_option[opt] == Strategy.REPRODUCIBLE + if _benchmark_kernel != _config._benchmark_kernel: + _clear_algorithm_cache() + _config._benchmark_kernel = _benchmark_kernel + _config._deterministic_kernel = _deterministic_kernel @deprecated(version="1.3", reason="use get_execution_strategy() instead") @@ -107,6 +112,3 @@ def get_conv_execution_strategy() -> str: @deprecated(version="1.3", reason="use set_execution_strategy() instead") def set_conv_execution_strategy(option: str): return set_execution_strategy(option) - - -set_execution_strategy(os.getenv("MEGENGINE_EXECUTION_STRATEGY", "HEURISTIC")) diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index efbb05093..c7b83ea1b 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -26,6 +26,7 @@ #include "megbrain/opr/io.h" #include "megbrain/plugin/profiler.h" #include "megbrain/utils/stats.h" +#include "megdnn/algorithm_cache.h" #include "./common.h" #include "./grad.h" @@ -1428,6 +1429,8 @@ void init_tensor(py::module m) { return set_amp_prec_dtype(false, dtype_name); }); + m.def("_clear_algorithm_cache", [] { megdnn::AlgorithmCache::instance().clear(); }); + py::register_exception(m, "TraceError"); } diff --git a/imperative/python/test/integration/test_correctness_mnistnet.py b/imperative/python/test/integration/test_correctness_mnistnet.py deleted file mode 100644 index ce33b5416..000000000 --- a/imperative/python/test/integration/test_correctness_mnistnet.py +++ /dev/null @@ -1,289 +0,0 @@ -# -*- coding: utf-8 -*- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-import os
-import re
-import subprocess
-import sys
-
-import numpy as np
-import pytest
-
-import megengine as mge
-import megengine.autodiff as ad
-import megengine.functional as F
-from megengine import jit
-from megengine.core._trace_option import set_symbolic_shape
-from megengine.core.ops import builtin
-from megengine.core.tensor.utils import make_shape_tuple
-from megengine.functional.debug_param import set_execution_strategy
-from megengine.jit import SublinearMemoryConfig
-from megengine.module import (
-    AdaptiveAvgPool2d,
-    AvgPool2d,
-    BatchNorm2d,
-    Conv2d,
-    Linear,
-    Module,
-)
-from megengine.optimizer import SGD
-from megengine.tensor import Tensor
-
-Strategy = builtin.ops.Convolution.Strategy
-
-
-def get_gpu_name():
-    try:
-        gpu_info = subprocess.check_output(
-            ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv,noheader"]
-        )
-        gpu_info = gpu_info.decode("ascii").split("\n")[0]
-    except:
-        gpu_info = "None"
-    return gpu_info
-
-
-def get_cpu_name():
-    cpu_info = "None"
-    try:
-        cpu_info = subprocess.check_output(["cat", "/proc/cpuinfo"]).decode("ascii")
-        for line in cpu_info.split("\n"):
-            if "model name" in line:
-                return re.sub(".*model name.*:", "", line, 1).strip()
-    except:
-        pass
-    return cpu_info
-
-
-def get_xpu_name():
-    if mge.is_cuda_available():
-        return get_gpu_name()
-    else:
-        return get_cpu_name()
-
-
-class MnistNet(Module):
-    def __init__(self, has_bn=False, use_adaptive_pooling=False):
-        super().__init__()
-        self.conv0 = Conv2d(1, 20, kernel_size=5, bias=True)
-        if use_adaptive_pooling:
-            self.pool0 = AdaptiveAvgPool2d(12)
-        else:
-            self.pool0 = AvgPool2d(2)
-        self.conv1 = Conv2d(20, 20, kernel_size=5, bias=True)
-        self.pool1 = AvgPool2d(2)
-        self.fc0 = Linear(20 * 4 * 4, 500, bias=True)
-        self.fc1 = Linear(500, 10, bias=True)
-        self.bn0 = None
-        self.bn1 = None
-        if has_bn:
-            self.bn0 = BatchNorm2d(20)
-            self.bn1 = BatchNorm2d(20)
-
-    def forward(self, x):
-        x = self.conv0(x)
-        if self.bn0:
-            x = self.bn0(x)
-        x = F.relu(x)
-        x = self.pool0(x)
-        x = self.conv1(x)
-        if self.bn1:
-            x = self.bn1(x)
-        x = F.relu(x)
-        x = self.pool1(x)
-        x = F.flatten(x, 1)
-        x = self.fc0(x)
-        x = F.relu(x)
-        x = self.fc1(x)
-        return x
-
-
-def train(data, label, net, opt, gm):
-    with gm:
-        pred = net(data)
-        loss = F.nn.cross_entropy(pred, label)
-        gm.backward(loss)
-        return loss
-
-
-def update_model(model_path):
-    """
-    Update the dumped model with test cases for new reference values.
-
-    The model with pre-trained weights is trained for one iter with the test data attached.
-    The loss and updated net state dict is dumped.
-
-    .. code-block:: python
-
-        from test_correctness import update_model
-        update_model('mnist_model_with_test.mge') # for gpu
-        update_model('mnist_model_with_test_cpu.mge') # for cpu
-
-    """
-    net = MnistNet(has_bn=True)
-    checkpoint = mge.load(model_path)
-    net.load_state_dict(checkpoint["net_init"])
-    lr = checkpoint["sgd_lr"]
-    opt = SGD(net.parameters(), lr=lr)
-    gm = ad.GradManager().attach(net.parameters())
-
-    data = Tensor(checkpoint["data"], dtype=np.float32)
-    label = Tensor(checkpoint["label"], dtype=np.int32)
-
-    opt.clear_grad()
-    loss = train(data, label, net, opt, gm)
-    opt.step()
-
-    xpu_name = get_xpu_name()
-
-    checkpoint.update(
-        {"net_updated": net.state_dict(), "loss": loss.numpy(), "xpu": xpu_name}
-    )
-    mge.save(checkpoint, model_path)
-
-
-def run_train(
-    model_path,
-    use_jit,
-    use_symbolic,
-    sublinear_memory_config=None,
-    max_err=None,
-    use_adaptive_pooling=False,
-):
-
-    """
-    Load the model with test cases and run the training for one iter.
-    The loss and updated weights are compared with reference value to verify the correctness.
-
-    Dump a new file with updated result by calling update_model
-    if you think the test fails due to numerical rounding errors instead of bugs.
-    Please think twice before you do so.
-
-    """
-    net = MnistNet(has_bn=True, use_adaptive_pooling=use_adaptive_pooling)
-    checkpoint = mge.load(model_path)
-    net.load_state_dict(checkpoint["net_init"])
-    lr = checkpoint["sgd_lr"]
-    opt = SGD(net.parameters(), lr=lr)
-    gm = ad.GradManager().attach(net.parameters())
-
-    data = Tensor(checkpoint["data"], dtype=np.float32)
-    label = Tensor(checkpoint["label"], dtype=np.int32)
-
-    if max_err is None:
-        max_err = 1e-5
-
-    train_func = train
-    if use_jit:
-        train_func = jit.trace(
-            train_func,
-            symbolic=use_symbolic,
-            sublinear_memory_config=sublinear_memory_config,
-        )
-
-    opt.clear_grad()
-    loss = train_func(data, label, net, opt, gm)
-    opt.step()
-
-    np.testing.assert_allclose(loss.numpy(), checkpoint["loss"], atol=max_err)
-
-    for param, param_ref in zip(
-        net.state_dict().items(), checkpoint["net_updated"].items()
-    ):
-        assert param[0] == param_ref[0]
-        if "bn" in param[0]:
-            ref = param_ref[1].reshape(param[1].shape)
-            np.testing.assert_allclose(param[1], ref, atol=max_err)
-        else:
-            np.testing.assert_allclose(param[1], param_ref[1], atol=max_err)
-
-
-def run_eval(
-    model_path,
-    use_symbolic,
-    sublinear_memory_config=None,
-    max_err=None,
-    use_adaptive_pooling=False,
-):
-
-    """
-    Load the model with test cases and run the training for one iter.
-    The loss and updated weights are compared with reference value to verify the correctness.
-
-    Dump a new file with updated result by calling update_model
-    if you think the test fails due to numerical rounding errors instead of bugs.
-    Please think twice before you do so.
-
-    """
-    net = MnistNet(has_bn=True, use_adaptive_pooling=use_adaptive_pooling)
-    checkpoint = mge.load(model_path)
-    net.load_state_dict(checkpoint["net_init"])
-
-    data = Tensor(checkpoint["data"], dtype=np.float32)
-
-    def eval_fun(data, *, net=None):
-        pred = net(data)
-        return pred
-
-    refer_value = eval_fun(data, net=net)
-    eval_fun = jit.trace(eval_fun, symbolic=use_symbolic)
-
-    for _ in range(3):
-        new_value = eval_fun(data, net=net)
-        np.testing.assert_allclose(new_value.numpy(), refer_value.numpy(), atol=max_err)
-
-
-@pytest.mark.skip(reason="close it when cu111 ci")
-def test_correctness():
-    if mge.is_cuda_available():
-        model_name = "mnist_model_with_test.mge"
-    else:
-        model_name = "mnist_model_with_test_cpu.mge"
-    model_path = os.path.join(os.path.dirname(__file__), model_name)
-    set_execution_strategy(Strategy.HEURISTIC | Strategy.REPRODUCIBLE)
-
-    run_train(model_path, False, False, max_err=1e-5)
-    run_train(model_path, True, False, max_err=1e-5)
-    run_train(model_path, True, True, max_err=1e-5)
-
-    # sublinear
-    config = SublinearMemoryConfig(genetic_nr_iter=10)
-    run_train(
-        model_path, True, True, sublinear_memory_config=config, max_err=1e-5,
-    )
-
-    run_eval(model_path, False, max_err=1e-7)
-    run_eval(model_path, True, max_err=1e-7)
-
-
-@pytest.mark.skip(reason="close it when cu111 ci")
-def test_correctness_use_adaptive_pooling():
-    if mge.is_cuda_available():
-        model_name = "mnist_model_with_test.mge"
-    else:
-        model_name = "mnist_model_with_test_cpu.mge"
-    model_path = os.path.join(os.path.dirname(__file__), model_name)
-    set_execution_strategy("HEURISTIC_REPRODUCIBLE")
-
-    run_train(model_path, False, False, max_err=1e-5, use_adaptive_pooling=True)
-    run_train(model_path, True, False, max_err=1e-5, use_adaptive_pooling=True)
-    run_train(model_path, True, True, max_err=1e-5, use_adaptive_pooling=True)
-
-    # sublinear
-    config = SublinearMemoryConfig(genetic_nr_iter=10)
-    run_train(
-        model_path,
-        True,
-        True,
-        sublinear_memory_config=config,
-        max_err=1e-5,
-        use_adaptive_pooling=True,
-    )
-
-    run_eval(model_path, False, max_err=1e-7, use_adaptive_pooling=True)
-    run_eval(model_path, True, max_err=1e-7, use_adaptive_pooling=True)
diff --git a/imperative/python/test/integration/test_dp_correctness.py b/imperative/python/test/integration/test_dp_correctness.py
index f4e986329..06b151d44 100644
--- a/imperative/python/test/integration/test_dp_correctness.py
+++ b/imperative/python/test/integration/test_dp_correctness.py
@@ -7,11 +7,8 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import os
-import platform
 import re
 import subprocess
-import sys
-from math import ceil
 
 import numpy as np
 import pytest
@@ -20,8 +17,6 @@ import megengine as mge
 import megengine.autodiff as ad
 import megengine.distributed as dist
 import megengine.functional as F
-from megengine.device import get_default_device, set_default_device
-from megengine.functional.debug_param import set_execution_strategy
 from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module
 from megengine.optimizer import SGD
 from megengine.tensor import Tensor
@@ -198,5 +193,7 @@ def run_test(
 def test_dp_correctness():
     model_name = "mnist_model_with_test.mge"
     model_path = os.path.join(os.path.dirname(__file__), model_name)
-    set_execution_strategy("HEURISTIC_REPRODUCIBLE")
+    old = mge.config.deterministic_kernel
+    mge.config.deterministic_kernel = True
     run_test(model_path, False, False, max_err=5e-5)
+    mge.config.deterministic_kernel = old
diff --git a/imperative/python/test/unit/module/test_conv.py b/imperative/python/test/unit/module/test_conv.py
index e782a077a..436324df7 100644
--- a/imperative/python/test/unit/module/test_conv.py
+++ b/imperative/python/test/unit/module/test_conv.py
@@ -11,21 +11,9 @@
 import itertools
 
 import numpy as np
 import pytest
 
+import megengine as mge
 import megengine.module as M
-from megengine import Parameter, tensor
-from megengine.functional.debug_param import (
-    get_execution_strategy,
-    set_execution_strategy,
-)
-from megengine.module import ConvTranspose2d, ConvTranspose3d, LocalConv2d
-
-
-@pytest.fixture
-def reproducible():
-    old = get_execution_strategy()
-    set_execution_strategy("HEURISTIC_REPRODUCIBLE")
-    yield
-    set_execution_strategy(old)
+from megengine import tensor
 
 
 # NOTE: test in module for convenience. should really test in functional
@@ -33,7 +21,9 @@
     "name",
     ["Conv1d", "Conv2d", "Conv3d", "ConvTranspose2d", "ConvTranspose3d", "LocalConv2d"],
 )
-def test_conv_dtype_promotion(name, reproducible):
+def test_conv_dtype_promotion(name):
+    old = mge.config.deterministic_kernel
+    mge.config.deterministic_kernel = True
     N, Ci, Co, K = 2, 16, 32, 3
     S = (7,) * int(name[-2])
     if "Local" in name:
@@ -42,3 +32,4 @@
     m = getattr(M, name)(Ci, Co, K)
     x = tensor(np.random.random(size=(N, Ci) + S).astype("float16"))
     np.testing.assert_equal(m(x).numpy(), m(x.astype("float32")).numpy())
+    mge.config.deterministic_kernel = old
diff --git a/imperative/python/test/unit/quantization/test_op.py b/imperative/python/test/unit/quantization/test_op.py
index 53929bb12..ab21c93c3 100644
--- a/imperative/python/test/unit/quantization/test_op.py
+++ b/imperative/python/test/unit/quantization/test_op.py
@@ -255,9 +255,8 @@ def test_conv_bias_int4():
     run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")
 
 
-@pytest.mark.require_ngpu(1)
 @pytest.mark.skipif(
-    get_cuda_compute_capability(0) < 61,
+    get_device_count("gpu") > 0 and get_cuda_compute_capability(0) < 61,
     reason="does not support int8 when gpu compute capability less than 6.1",
 )
 def test_conv_transpose2d():
diff --git a/imperative/python/test/unit/utils/test_network_node.py b/imperative/python/test/unit/utils/test_network_node.py
index 7ec1b7a5e..fd2361e46 100644
--- a/imperative/python/test/unit/utils/test_network_node.py
+++ b/imperative/python/test/unit/utils/test_network_node.py
@@ -5,6 +5,7 @@ import platform
 
 import numpy as np
 import pytest
 
+import megengine as mge
 import megengine.core.tensor.dtype as dtype
 import megengine.core.tensor.megbrain_graph as G
 import megengine.functional as F
@@ -18,10 +19,6 @@ from megengine.device import (
     get_device_count,
     is_cuda_available,
 )
-from megengine.functional.debug_param import (
-    get_execution_strategy,
-    set_execution_strategy,
-)
 from megengine.functional.external import tensorrt_runtime_opr
 from megengine.jit.tracing import trace
 from megengine.tensor import Tensor
@@ -110,25 +107,30 @@ def test_matinv():
 
 
 @pytest.mark.parametrize(
-    "execution_strategy", ["HEURISTIC_REPRODUCIBLE", "PROFILE_REPRODUCIBLE"]
+    "benchmark_kernel, max_err", [(False, None), (True, 1e-5)],
 )
-def test_matmul(execution_strategy):
+def test_matmul(monkeypatch, benchmark_kernel, max_err):
+    if get_device_count("gpu") == 0 and benchmark_kernel:
+        return
+    monkeypatch.setenv("MGE_FASTRUN_CACHE_TYPE", "MEMORY")
+    old1, old2 = (
+        mge.config.benchmark_kernel,
+        mge.config.deterministic_kernel,
+    )
+    mge.config.benchmark_kernel = benchmark_kernel
+    mge.config.deterministic_kernel = True
+
     @trace(symbolic=True, capture_as_const=True)
     def fwd(data1, data2):
         return F.matmul(data1, data2)
 
-    old = get_execution_strategy()
-    set_execution_strategy(execution_strategy)
-
-    max_err = None
-    if execution_strategy == "PROFILE_REPRODUCIBLE":
-        max_err = 1e-5
-
     data1 = Tensor(np.random.random((32, 64)))
     data2 = Tensor(np.random.random((64, 16)))
     result = fwd(data1, data2)
     check_pygraph_dump(fwd, [data1, data2], [result], max_err=max_err)
-    set_execution_strategy(old)
+    mge.config.benchmark_kernel = old1
+    mge.config.deterministic_kernel = old2
+    monkeypatch.delenv("MGE_FASTRUN_CACHE_TYPE", raising=False)
 
 
 def test_batchmatmul():
@@ -290,9 +292,8 @@ def test_deformable_ps_roi_pooling():
     check_pygraph_dump(fwd, [inp, rois, trans], [result])
 
 
-@pytest.mark.require_ngpu(1)
 @pytest.mark.skipif(
-    get_cuda_compute_capability(0) < 61,
+    get_device_count("gpu") > 0 and get_cuda_compute_capability(0) < 61,
    reason="does not support int8 when gpu compute capability less than 6.1",
 )
 def test_convbias():
-- 
GitLab
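
Usage sketch: the hunks above route strategy changes through megdnn::AlgorithmCache::clear(),
either implicitly via the config setters or explicitly via the new _clear_algorithm_cache
binding added in tensor.cpp. A minimal example of the intended use, assuming a megengine
build that includes this revision (module and attribute names are taken from the hunks
above, not from released documentation):

    import megengine as mge
    from megengine.core._imperative_rt.core2 import _clear_algorithm_cache

    # Assigning a *different* value to benchmark_kernel now clears the
    # algorithm cache implicitly (see the benchmark_kernel setter in
    # _config.py); re-assigning the same value leaves the cache alone.
    old = mge.config.benchmark_kernel
    mge.config.benchmark_kernel = not old

    # ... run some workload here; algorithm choices are profiled or picked
    # heuristically under the new strategy and cached again as ops execute ...

    # The cache can also be dropped explicitly through the new binding; this
    # is the same C++ entry point the config layer calls internally.
    _clear_algorithm_cache()

    mge.config.benchmark_kernel = old  # restore; clears the cache once more

The updated tests follow the same save/toggle/restore pattern with
mge.config.deterministic_kernel in place of the removed
set_execution_strategy("HEURISTIC_REPRODUCIBLE") calls.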