# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
import platform
import sys

import pytest

from megengine.core import _config as config
from megengine.core import _trace_option as trace_option
from megengine.core import get_option
from megengine.core._imperative_rt.core2 import (
    _get_amp_dtype_autocast,
    _get_amp_high_prec_dtype,
    _get_amp_low_prec_dtype,
    _get_convert_inputs,
)
from megengine.core.tensor import amp
from megengine.device import get_device_count

sys.path.append(os.path.join(os.path.dirname(__file__), "helpers"))

_ngpu = get_device_count("gpu")


@pytest.fixture(autouse=True)
def skip_by_ngpu(request):
    # Skip tests marked with ``require_ngpu(n)`` when fewer than n GPUs are available.
    if request.node.get_closest_marker("require_ngpu"):
        require_ngpu = int(request.node.get_closest_marker("require_ngpu").args[0])
        if require_ngpu > _ngpu:
            pytest.skip("skipped for ngpu unsatisfied: {}".format(require_ngpu))


@pytest.fixture(autouse=True)
def skip_distributed(request):
    # Skip ``distributed_isolated`` tests on platforms without distributed support.
    if request.node.get_closest_marker("distributed_isolated"):
        if platform.system() in ("Windows", "Darwin"):
            pytest.skip(
                "skipped for distributed unsupported at platform: {}".format(
                    platform.system()
                )
            )


def _snapshot_global_options():
    # Collect the global options and config flags that a test must not leak changes to.
    return {
        "symbolic_shape": trace_option.use_symbolic_shape(),
        "async_level": get_option("async_level"),
        "enable_drop": get_option("enable_drop"),
        "max_recompute_time": get_option("max_recompute_time"),
        "catch_worker_execption": get_option("catch_worker_execption"),
        "enable_host_compute": get_option("enable_host_compute"),
        # "record_computing_path": get_option("record_computing_path"),
        "disable_memory_forwarding": get_option("disable_memory_forwarding"),
        "enable_dtr_auto_drop": get_option("enable_dtr_auto_drop"),
        "enable_dtr_sqrt_sampling": get_option("enable_dtr_sqrt_sampling"),
        "dtr_eviction_threshold": get_option("dtr_eviction_threshold"),
        "dtr_evictee_minimum_size": get_option("dtr_evictee_minimum_size"),
        "benchmark_kernel": config.benchmark_kernel,
        "deterministic_kernel": config.deterministic_kernel,
        "compute_mode": config._compute_mode,
        "conv_format": config._conv_format,
        "amp_enabled": amp.enabled,
        "convert_inputs": _get_convert_inputs(),
        "amp_dtype_autocast": _get_amp_dtype_autocast(),
        "amp_high_prec_dtype": _get_amp_high_prec_dtype(),
        "amp_low_prec_dtype": _get_amp_low_prec_dtype(),
    }


@pytest.fixture(autouse=True)
def run_around_tests():
    # Snapshot global state before each test and verify the test restored it afterwards.
    env_vars1 = _snapshot_global_options()
    yield
    env_vars2 = _snapshot_global_options()
    for key in env_vars1:
        assert (
            env_vars1[key] == env_vars2[key]
        ), "{} has been changed after test".format(key)
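

# Usage sketch (not part of the original conftest): the commented block below
# illustrates how a hypothetical test module would rely on the autouse fixtures
# defined above. The test names and bodies are illustrative assumptions, not real
# MegEngine tests; only the marker names ``require_ngpu`` and
# ``distributed_isolated`` come from this file.
#
#     import pytest
#
#     @pytest.mark.require_ngpu(2)        # skipped by skip_by_ngpu if < 2 GPUs are visible
#     def test_needs_two_gpus():
#         ...
#
#     @pytest.mark.distributed_isolated   # skipped by skip_distributed on Windows / macOS
#     def test_distributed_case():
#         ...
#
# ``run_around_tests`` needs no marker: it wraps every test automatically and asserts
# that global options (async_level, DTR settings, AMP state, ...) were restored.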