Commit cc6084b4 authored by Megvii Engine Team, committed by Wanwan1996

revert: chore(mge/misc): api coverage

This reverts the following commits:

77fd432cb24678c2351a7656ddd44339618f9cb6
0d7126641b467f4d33b71c52824c0b384830f1c2
559d205a8c8c0d658fc2e7eacd4d536e5aead4f0
6d8e7398e4a06ee90c0bdbebec3f0290c12b5aa4
b72517bfe9c4b4e3efde5badc95fdf89b0ea48cd

GitOrigin-RevId: 18dc4ce999485ae15d37ae09bd3d799b968fd261
Parent 2a4295fb
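Across the hunks below, the reverted commits had renamed many symbols to underscore-prefixed (private) spellings, apparently as part of an API-coverage cleanup; this revert restores the earlier public names, so "-" lines carry the private spellings and "+" lines the restored ones. One representative symbol (hypothetical usage):

    # before this revert: from megengine.jit.tracing import _is_tracing
    # after this revert:  from megengine.jit.tracing import is_tracing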
@@ -19,11 +19,11 @@ logger = get_logger(__name__)
 backwarding_grad_manager = None
 
-def _get_backwarding_grad_manager():
+def get_backwarding_grad_manager():
     return backwarding_grad_manager
 
-class _AttachSpec:
+class AttachSpec:
     __slots__ = "tensor", "callbacks"
@@ -118,7 +118,7 @@ class GradManager:
     """
 
     def __init__(self):
-        self._attach_specs = {}  # id(Tensor) -> _AttachSpec
+        self._attach_specs = {}  # id(Tensor) -> AttachSpec
         self._recording = False
         self._grad = None
         self._after_backward_callback = []
@@ -200,7 +200,7 @@ class GradManager:
             if self is not None:
                 del self._attach_specs[key]
 
-        spec = _AttachSpec()
+        spec = AttachSpec()
         spec.tensor = weakref.ref(tensor, deleter)
         spec.callbacks = []
         return spec
@@ -354,22 +354,22 @@ class GradManager:
     def __or__(self, other):
         if isinstance(other, GradManager):
-            return _GradManagerGroup([self, other])
+            return GradManagerGroup([self, other])
         return NotImplemented
 
     __ror__ = __or__
 
-class _GradManagerGroup:
+class GradManagerGroup:
     def __init__(self, gms) -> None:
         self._gms = list(gms)
 
     def merge_with(self, other):
         if isinstance(other, GradManager):
-            other = _GradManagerGroup([other])
-        elif not isinstance(other, _GradManagerGroup):
+            other = GradManagerGroup([other])
+        elif not isinstance(other, GradManagerGroup):
            return NotImplemented
-        return _GradManagerGroup([*self._gms, *other._gms])
+        return GradManagerGroup([*self._gms, *other._gms])
 
     __or__ = merge_with
     __ror__ = merge_with
......
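For context, a minimal sketch (hypothetical usage, not part of this diff) of how the restored GradManagerGroup is reached through the | operator shown above:

    import megengine.autodiff as autodiff

    gm_a, gm_b = autodiff.GradManager(), autodiff.GradManager()
    merged = gm_a | gm_b                      # GradManager.__or__ -> GradManagerGroup
    merged = merged | autodiff.GradManager()  # merge_with extends the group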
@@ -34,7 +34,7 @@ logger = get_logger(__name__)
 GLOBAL_TIMEOUT = 5
 
-def _raise_timeout_error():
+def raise_timeout_error():
     raise RuntimeError("dataloader timeout")
@@ -191,7 +191,7 @@ class DataLoader:
         )
 
-class _PreLoader:
+class PreLoader:
     def __init__(self, loader, preload):
         self.dataset = loader.dataset
         self.sampler = loader.sampler
@@ -319,7 +319,7 @@ class _ParallelDataLoaderIter:
             if success:
                 return data
             else:
-                _raise_timeout_error()
+                raise_timeout_error()
         else:
             while True:
                 success, data = self._try_get_data()
@@ -417,7 +417,7 @@ class _ParallelDataLoaderIter:
         self._shutdown_workers()
 
-class _BaseMapDataLoaderIter(_PreLoader):
+class _BaseMapDataLoaderIter(PreLoader):
     def __init__(self, loader, preload):
         super().__init__(loader, preload)
@@ -510,7 +510,7 @@ def get_worker_info():
     return _worker_info
 
-class _BaseStreamDataLoaderIter(_PreLoader):
+class _BaseStreamDataLoaderIter(PreLoader):
     def __init__(self, loader, preload):
         super().__init__(loader, preload)
         self.dataset_iter = iter(self.dataset)
@@ -552,7 +552,7 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
             timer.cancel()
         waited_time = time.time() - start_time
         if waited_time > self.timeout:
-            _raise_timeout_error()
+            raise_timeout_error()
         return raw_data
 
     def _get_next_batch(self):
@@ -583,7 +583,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter, _ParallelDataLoad
                 place_holder = [next(self.dataset_iter)]
             waited_time = time.time() - start_time
             if self.timeout > 0 and waited_time > self.timeout:
-                _raise_timeout_error()
+                raise_timeout_error()
             place_holder = self._get_remaind_data(place_holder)
         else:
             place_holder = next(self._sampler_iter)
......
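The timeout helper renamed above simply raises, so callers see an ordinary RuntimeError; a minimal sketch, assuming the post-revert name is importable from the dataloader module:

    from megengine.data.dataloader import raise_timeout_error

    try:
        raise_timeout_error()
    except RuntimeError as exc:
        assert str(exc) == "dataloader timeout"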
@@ -21,7 +21,7 @@ def _count_visible_keypoints(anno):
     return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
 
-def _has_valid_annotation(anno, order):
+def has_valid_annotation(anno, order):
     # if it"s empty, there is no annotation
     if len(anno) == 0:
         return False
@@ -101,7 +101,7 @@ class COCO(VisionDataset):
                 anno = [
                     obj for obj in anno if obj["bbox"][2] > 0 and obj["bbox"][3] > 0
                 ]
-                if _has_valid_annotation(anno, order):
+                if has_valid_annotation(anno, order):
                     ids.append(img_id)
                     self.img_to_anns[img_id] = anno
                 else:
......
@@ -140,17 +140,17 @@ class MNIST(VisionDataset):
         # load raw files and transform them into meta data and datasets Tuple(np.array)
         logger.info("process the raw files of %s set...", "train" if train else "test")
 
         if train:
-            meta_data_images, images = _parse_idx3(
+            meta_data_images, images = parse_idx3(
                 os.path.join(self.root, self.raw_file_name[0])
             )
-            meta_data_labels, labels = _parse_idx1(
+            meta_data_labels, labels = parse_idx1(
                 os.path.join(self.root, self.raw_file_name[1])
             )
         else:
-            meta_data_images, images = _parse_idx3(
+            meta_data_images, images = parse_idx3(
                 os.path.join(self.root, self.raw_file_name[2])
             )
-            meta_data_labels, labels = _parse_idx1(
+            meta_data_labels, labels = parse_idx1(
                 os.path.join(self.root, self.raw_file_name[3])
             )
@@ -161,7 +161,7 @@ class MNIST(VisionDataset):
         self.arrays = (images, labels.astype(np.int32))
 
-def _parse_idx3(idx3_file):
+def parse_idx3(idx3_file):
     # parse idx3 file to meta data and data in numpy array (images)
     logger.debug("parse idx3 file %s ...", idx3_file)
     assert idx3_file.endswith(".gz")
@@ -187,7 +187,7 @@ def _parse_idx3(idx3_file):
     return meta_data, images
 
-def _parse_idx1(idx1_file):
+def parse_idx1(idx1_file):
     # parse idx1 file to meta data and data in numpy array (labels)
     logger.debug("parse idx1 file %s ...", idx1_file)
     assert idx1_file.endswith(".gz")
......
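For context, a standalone sketch of what parse_idx3 does with a gzipped IDX3 image file (a standard-library reimplementation under the usual IDX conventions, not MegEngine's code):

    import gzip
    import struct

    import numpy as np

    def parse_idx3_sketch(path):
        with gzip.open(path, "rb") as f:
            # IDX3 header: big-endian magic (2051), image count, rows, cols
            magic, num, rows, cols = struct.unpack(">IIII", f.read(16))
            assert magic == 2051, "not an IDX3 image file"
            images = np.frombuffer(f.read(), dtype=np.uint8).reshape(num, rows, cols)
        return {"magic": magic, "num": num, "rows": rows, "cols": cols}, images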
@@ -7,7 +7,7 @@ import cv2
 import numpy as np
 
 from megengine.data.transform import Transform
-from megengine.data.transform.vision import _functional as F
+from megengine.data.transform.vision import functional as F
 
 __all__ = [
     "VisionTransform",
......
@@ -2,6 +2,7 @@
 from mprop import mproperty
 
 from ..core._imperative_rt.core2 import group_end, group_start
+from . import group
 from .group import (
     WORLD,
     Group,
@@ -19,7 +20,7 @@ from .group import (
 )
 from .helper import bcast_list_, make_allreduce_cb, synchronized
 from .launcher import launcher
-from .server import Server
+from .server import Client, Server
 
 @mproperty
......
@@ -7,10 +7,10 @@ from mprop import mproperty
 from ..device import _sh, set_default_device, what_is_xpu
 from ..random import seed
-from .server import Server, _Client
+from .server import Client, Server
 
-class _StaticData:
+class StaticData:
     server = None
     client = None
     master_ip = None
@@ -139,13 +139,13 @@ def init_process_group(
     global _sd
     assert _sd is None, "init_process_group should be called only once"
-    _sd = _StaticData()
+    _sd = StaticData()
 
     assert world_size > 1
     assert rank >= 0 and rank < world_size
     assert port > 0
 
-    _sd.client = _Client(master_ip, port)
+    _sd.client = Client(master_ip, port)
     _sd.master_ip = master_ip
     _sd.py_server_port = port
     _sd.mm_server_port = _sd.client.get_mm_server_port()
@@ -225,7 +225,7 @@ def get_mm_server_addr() -> Tuple[str, int]:
     return _sd.master_ip, _sd.mm_server_port
 
-def get_client() -> _Client:
+def get_client() -> Client:
     r"""Get client of python XML RPC server."""
     assert _sd is not None, "please call init_process_group first"
     return _sd.client
......
@@ -7,7 +7,7 @@ from weakref import WeakSet
 import numpy as np
 
-from megengine.autodiff.grad_manager import GradManager, _get_backwarding_grad_manager
+from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager
 
 from ..core._imperative_rt.core2 import apply
 from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
@@ -78,7 +78,7 @@ def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list):
     return apply(op, *inps, offsets)[0]
 
-def _get_offsets(shapes):
+def get_offsets(shapes):
     offsets = []
     offset = 0
     for shape in shapes:
@@ -108,7 +108,7 @@ def _check_enable_p2p():
 
 def pack_allreduce_split(pack_list, shapes, group, reduce_method):
-    offsets_val = _get_offsets(shapes)
+    offsets_val = get_offsets(shapes)
     offsets = Tensor(offsets_val)
     packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
@@ -119,7 +119,7 @@ def pack_allreduce_split(pack_list, shapes, group, reduce_method):
     return grads
 
-class _TensorFuture(Future):
+class TensorFuture(Future):
     def device(self):
         raise "Sorry, this tensor is not ready"
@@ -234,13 +234,13 @@ class AllreduceCallback:
             self._packing_size[dtype] = 0
 
     def __call__(self, param, grad):
-        gm = _get_backwarding_grad_manager()
+        gm = get_backwarding_grad_manager()
         assert isinstance(gm, GradManager)
         if gm not in self._marked_gm:
             gm._register_after_backward_callback(self._flush)
             self._marked_gm.add(gm)
         self._params.append(param)
-        self._futures_dict[param] = _TensorFuture(ack=False)
+        self._futures_dict[param] = TensorFuture(ack=False)
         self._gradients_dict[param] = grad
         self._grad_origin_device[param] = str(grad.device)
......
@@ -10,7 +10,7 @@ from ..device import get_device_count
 from ..logger import get_logger
 from .group import _set_machine_ranks, group_barrier, init_process_group
 from .helper import _check_device_initialized, _check_interpreter_status
-from .server import Server
+from .server import Client, Server
 
 WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = (
     "subprocess exited with code 0 but did not return a value"
......
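A hypothetical two-GPU sketch of the launcher imported above (the decorator spawns one worker process per device; the n_gpus keyword is an assumption of this sketch):

    import megengine.distributed as dist

    @dist.launcher(n_gpus=2)
    def worker():
        print("rank", dist.get_rank())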
@@ -12,7 +12,7 @@ from ..core._imperative_rt.utils import create_mm_server
 from ..utils.future import Future
 
-class _Methods:
+class Methods:
     r"""Distributed Server Method.
 
     Used for exchange information between distributed nodes.
@@ -149,7 +149,7 @@ class _Methods:
         return ret
 
-class _ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
+class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
     pass
@@ -163,10 +163,10 @@ def _start_server(py_server_port, queue):
     """
     try:
         mm_server_port = create_mm_server("0.0.0.0", 0)
-        server = _ThreadXMLRPCServer(
+        server = ThreadXMLRPCServer(
             ("0.0.0.0", py_server_port), logRequests=False, allow_none=True
         )
-        server.register_instance(_Methods(mm_server_port))
+        server.register_instance(Methods(mm_server_port))
         _, py_server_port = server.server_address
         queue.put((py_server_port, mm_server_port))
         server.serve_forever()
@@ -196,7 +196,7 @@ class Server:
         self.proc.terminate()
 
-class _Client:
+class Client:
     r"""Distributed Client for distributed training.
 
     Args:
@@ -298,10 +298,10 @@ class _Client:
         return self.proxy.bcast_val(val, key, size)
 
-def _main(port=0, verbose=True):
+def main(port=0, verbose=True):
     mm_server_port = create_mm_server("0.0.0.0", 0)
-    server = _ThreadXMLRPCServer(("0.0.0.0", port), logRequests=verbose)
-    server.register_instance(_Methods(mm_server_port))
+    server = ThreadXMLRPCServer(("0.0.0.0", port), logRequests=verbose)
+    server.register_instance(Methods(mm_server_port))
     _, port = server.server_address
     print("serving on port", port)
     server.serve_forever()
@@ -314,4 +314,4 @@ if __name__ == "__main__":
     ap.add_argument("-p", "--port", type=int, default=0)
     ap.add_argument("-v", "--verbose", type=bool, default=True)
     args = ap.parse_args()
-    _main(port=args.port, verbose=args.verbose)
+    main(port=args.port, verbose=args.verbose)
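With main public again, this server file can still be started directly from the command line using the flags defined above, e.g. python server.py -p 0; binding port 0 lets the OS pick a free port, which is then printed by the "serving on port" line.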
 # -*- coding: utf-8 -*-
 # pylint: disable=redefined-builtin
+from . import metric, utils, vision
 from .elemwise import *
 from .math import *
 from .nn import *
 from .tensor import *
-from .utils import *
-from . import utils, vision, distributed  # isort:skip
+from . import distributed  # isort:skip
 
 # delete namespace
 # pylint: disable=undefined-variable
......
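With the restored imports, topk_accuracy is reachable both as a top-level functional and through the metric namespace (via the `from .metric import *` addition near the end of this diff); a hypothetical call with made-up values:

    import megengine.functional as F
    from megengine import tensor

    logits = tensor([[0.1, 0.9], [0.8, 0.2]])
    target = tensor([1, 0])
    acc = F.topk_accuracy(logits, target, topk=1)  # same function as F.metric.topk_accuracy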
@@ -21,10 +21,6 @@ _valid_string_option = {
 }
 
-@deprecated(
-    version="1.10",
-    reason="use ``megengine.config.benchmark_kernel`` and ``megengine.config.deterministic_kernel`` instead",
-)
 def get_execution_strategy() -> Strategy:
     r"""Returns the execution strategy of :class:`~.module.Conv2d` and :func:`~.matmul`
@@ -40,10 +36,6 @@ def get_execution_strategy() -> Strategy:
     return strategy
 
-@deprecated(
-    version="1.10",
-    reason="use ``megengine.config.benchmark_kernel`` and ``megengine.config.deterministic_kernel`` instead",
-)
 def set_execution_strategy(option):
     r"""Sets the execution strategy of :class:`~.module.Conv2d` and :func:`~.matmul`
......
@@ -9,7 +9,7 @@ from ..core.tensor.array_method import _elwise
 from ..core.tensor.utils import convert_inputs
 from ..tensor import Tensor
 from ..utils.deprecation import deprecated_func
-from ._tensor_cache import get_scalar_one
+from .tensor_cache import get_scalar_one
 
 __all__ = [
     "abs",
......
@@ -43,7 +43,6 @@ from .debug_param import get_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import _elwise, exp, log, log1p, maximum, minimum
 from .math import max, normalize, sum
-from .metric import topk_accuracy
 from .tensor import broadcast_to, concat, expand_dims, ones, squeeze, zeros
 
 __all__ = [
@@ -87,7 +86,6 @@ __all__ = [
     "softmax",
     "softplus",
     "sync_batch_norm",
-    "topk_accuracy",
     "warp_affine",
     "warp_perspective",
     "pixel_shuffle",
@@ -95,7 +93,7 @@ __all__ = [
 ]
 
-def _expand_hw(x):
+def expand_hw(x):
     # judge int is 5 times faster than judge Sequence
     if isinstance(x, int):
         return x, x
@@ -104,7 +102,7 @@ def _expand_hw(x):
     return int(x), int(x)
 
-def _expand_dhw(x):
+def expand_dhw(x):
     if isinstance(x, int):
         return x, x, x
     if isinstance(x, Sequence):
@@ -246,9 +244,9 @@ def conv2d(
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    stride_h, stride_w = _expand_hw(stride)
-    pad_h, pad_w = _expand_hw(padding)
-    dilate_h, dilate_w = _expand_hw(dilation)
+    stride_h, stride_w = expand_hw(stride)
+    pad_h, pad_w = expand_hw(padding)
+    dilate_h, dilate_w = expand_hw(dilation)
 
     sparse_type = "dense" if groups == 1 else "group"
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
@@ -308,9 +306,9 @@ def conv3d(
     D, H, W = 0, 1, 2
 
-    pad = _expand_dhw(padding)
-    stride = _expand_dhw(stride)
-    dilate = _expand_dhw(dilation)
+    pad = expand_dhw(padding)
+    stride = expand_dhw(stride)
+    dilate = expand_dhw(dilation)
 
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution3D(
@@ -378,10 +376,10 @@ def conv_transpose2d(
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    stride_h, stride_w = _expand_hw(stride)
-    pad_h, pad_w = _expand_hw(padding)
-    output_pad_h, output_pad_w = _expand_hw(output_padding)
-    dilate_h, dilate_w = _expand_hw(dilation)
+    stride_h, stride_w = expand_hw(stride)
+    pad_h, pad_w = expand_hw(padding)
+    output_pad_h, output_pad_w = expand_hw(output_padding)
+    dilate_h, dilate_w = expand_hw(dilation)
 
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
     sparse_type = "dense" if groups == 1 else "group"
@@ -479,9 +477,9 @@ def deformable_conv2d(
         offset = offset.astype("float32")
         mask = mask.astype("float32")
-    stride_h, stride_w = _expand_hw(stride)
-    pad_h, pad_w = _expand_hw(padding)
-    dilate_h, dilate_w = _expand_hw(dilation)
+    stride_h, stride_w = expand_hw(stride)
+    pad_h, pad_w = expand_hw(padding)
+    dilate_h, dilate_w = expand_hw(dilation)
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
     sparse_type = "dense" if groups == 1 else "group"
@@ -533,9 +531,9 @@ def local_conv2d(
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    stride_h, stride_w = _expand_hw(stride)
-    pad_h, pad_w = _expand_hw(padding)
-    dilate_h, dilate_w = _expand_hw(dilation)
+    stride_h, stride_w = expand_hw(stride)
+    pad_h, pad_w = expand_hw(padding)
+    dilate_h, dilate_w = expand_hw(dilation)
 
     # local conv only support "dense" mode, but weight could contain group dimension.
     op = builtin.GroupLocal(
@@ -589,10 +587,10 @@ def conv_transpose3d(
         output tensor.
     """
     D, H, W = 0, 1, 2
-    pad = _expand_dhw(padding)
-    stride = _expand_dhw(stride)
-    dilate = _expand_dhw(dilation)
-    output_padding = _expand_dhw(output_padding)
+    pad = expand_dhw(padding)
+    stride = expand_dhw(stride)
+    dilate = expand_dhw(dilation)
+    output_padding = expand_dhw(output_padding)
 
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution3DBackwardData(
@@ -677,9 +675,9 @@ def max_pool2d(
     """
     if stride is None:
         stride = kernel_size
-    window_h, window_w = _expand_hw(kernel_size)
-    stride_h, stride_w = _expand_hw(stride)
-    padding_h, padding_w = _expand_hw(padding)
+    window_h, window_w = expand_hw(kernel_size)
+    stride_h, stride_w = expand_hw(stride)
+    padding_h, padding_w = expand_hw(padding)
 
     op = builtin.Pooling(
         window_h=window_h,
@@ -727,9 +725,9 @@ def avg_pool2d(
     """
     if stride is None:
         stride = kernel_size
-    window_h, window_w = _expand_hw(kernel_size)
-    stride_h, stride_w = _expand_hw(stride)
-    padding_h, padding_w = _expand_hw(padding)
+    window_h, window_w = expand_hw(kernel_size)
+    stride_h, stride_w = expand_hw(stride)
+    padding_h, padding_w = expand_hw(padding)
 
     op = builtin.Pooling(
         window_h=window_h,
@@ -1725,10 +1723,10 @@ def sliding_window(
         stride: stride of the window. Default: 1
         dilation: dilation of the window. Default: 1
     """
-    padding_h, padding_w = _expand_hw(padding)
-    stride_h, stride_w = _expand_hw(stride)
-    dilation_h, dilation_w = _expand_hw(dilation)
-    window_h, window_w = _expand_hw(kernel_size)
+    padding_h, padding_w = expand_hw(padding)
+    stride_h, stride_w = expand_hw(stride)
+    dilation_h, dilation_w = expand_hw(dilation)
+    window_h, window_w = expand_hw(kernel_size)
 
     op = builtin.Images2Neibs(
         pad_h=padding_h,
@@ -1764,11 +1762,11 @@ def sliding_window_transpose(
         stride: stride of the window. Default: 1
         dilation: dilation of the window. Default: 1
     """
-    output_h, output_w = _expand_hw(output_size)
-    padding_h, padding_w = _expand_hw(padding)
-    stride_h, stride_w = _expand_hw(stride)
-    dilation_h, dilation_w = _expand_hw(dilation)
-    window_h, window_w = _expand_hw(kernel_size)
+    output_h, output_w = expand_hw(output_size)
+    padding_h, padding_w = expand_hw(padding)
+    stride_h, stride_w = expand_hw(stride)
+    dilation_h, dilation_w = expand_hw(dilation)
+    window_h, window_w = expand_hw(kernel_size)
 
     expected_h = (
         output_h + 2 * padding_h - dilation_h * (window_h - 1) - 1
@@ -1921,7 +1919,7 @@ def _get_layerPixelShuffle(device, dtype, dim_order):
     return layerPixelShuffle
 
-def _layerPixelShuffle_traceable(inp, upscale_factor):
+def layerPixelShuffle_traceable(inp, upscale_factor):
     assert upscale_factor > 0, "upscale_factor should larger than 0"
     assert inp.ndim >= 3, "the input dimension of pixel_shuffle should be larger than 3"
     assert (
@@ -1972,7 +1970,7 @@ def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
     :param upscale_factor: upscale factor of pixel_shuffle.
     :return: output tensor.
     """
-    return pixel_shuffle_cpp(inp, upscale_factor, _layerPixelShuffle_traceable)
+    return pixel_shuffle_cpp(inp, upscale_factor, layerPixelShuffle_traceable)
 
 def region_restricted_conv(
@@ -2014,9 +2012,9 @@ def region_restricted_conv(
     """
     assert conv_mode.lower() == "cross_correlation"
 
-    pad_h, pad_w = _expand_hw(padding)
-    stride_h, stride_w = _expand_hw(stride)
-    dilate_h, dilate_w = _expand_hw(dilation)
+    pad_h, pad_w = expand_hw(padding)
+    stride_h, stride_w = expand_hw(stride)
+    dilate_h, dilate_w = expand_hw(dilation)
 
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.RegionRestrictedConvolution(
@@ -2038,4 +2036,5 @@ def region_restricted_conv(
 
 from .quantized import conv_bias_activation  # isort:skip
 from .loss import *  # isort:skip
+from .metric import *  # isort:skip
 from .vision import *  # isort:skip
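A minimal sketch of the expand_hw normalization that the conv/pooling wrappers above rely on: an int is duplicated into a symmetric (h, w) pair, anything sequence-like is coerced elementwise (behavior inferred from the hunks above, standalone reimplementation):

    def expand_hw_sketch(x):
        if isinstance(x, int):
            return x, x
        h, w = x
        return int(h), int(w)

    assert expand_hw_sketch(3) == (3, 3)
    assert expand_hw_sketch((2, 4)) == (2, 4)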
 from ..core._imperative_rt.core2 import Const
-from ..jit.tracing import _is_tracing
+from ..jit.tracing import is_tracing
 
 small_tensor_cache = {}
 
 def _get_scalar_tensor_with_value(value, dtype=None, device=None):
     global small_tensor_cache
-    if _is_tracing():
+    if is_tracing():
         ret = Const(value, dtype, device)
     else:
         cache_key = (value, dtype, device)
......
@@ -7,6 +7,8 @@ from ..utils.deprecation import deprecated_func
 from .elemwise import abs, maximum, minimum
 from .tensor import ones, zeros
 
+__all__ = ["topk_accuracy"]
+
 def _assert_equal(
     expect: Tensor, actual: Tensor, *, maxerr: float = 0.0001, verbose: bool = False
......
@@ -36,7 +36,7 @@ pattern = re.compile(
 )
 
-class _RepoFetcherBase:
+class RepoFetcherBase:
     @classmethod
     def fetch(
         cls,
@@ -84,7 +84,7 @@ class _RepoFetcherBase:
         return hashlib.sha1(repo_dir.encode()).hexdigest()[:16]
 
-class GitSSHFetcher(_RepoFetcherBase):
+class GitSSHFetcher(RepoFetcherBase):
     @classmethod
     @synchronized
     def fetch(
@@ -193,7 +193,7 @@ class GitSSHFetcher(_RepoFetcherBase):
     )
 
-class GitHTTPSFetcher(_RepoFetcherBase):
+class GitHTTPSFetcher(RepoFetcherBase):
     @classmethod
     @synchronized
     def fetch(
......
@@ -49,7 +49,7 @@ active_trace = None
 skip_tracing = False
 
-def _is_tracing():
+def is_tracing():
     if active_trace is None:
         return False
     else:
@@ -73,7 +73,7 @@ def exclude_from_trace():
     skip_tracing = False
 
-def _array_comparator(lhs, rhs):
+def array_comparator(lhs, rhs):
     return np.all(lhs == rhs)
@@ -184,7 +184,7 @@ class trace:
         self._trace.no_exec = record_only
         self._trace.options_visitor = apply_options
         self._trace.profile = profiling
-        self._trace.array_comparator = _array_comparator
+        self._trace.array_comparator = array_comparator
         self._trace.record_input_shapes = _input_node_use_static_shape()
 
     def __call__(self, *args, **kwargs):
......
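A hypothetical check of the restored public name; per the definition above, no trace is active at plain import time, so it reports False:

    from megengine.jit.tracing import is_tracing

    assert is_tracing() is False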
@@ -18,10 +18,10 @@ def set_log_file(fout, mode="a"):
     """
     if isinstance(fout, str):
         fout = open(fout, mode)
-    _MegEngineLogFormatter.log_fout = fout
+    MegEngineLogFormatter.log_fout = fout
 
-class _MegEngineLogFormatter(logging.Formatter):
+class MegEngineLogFormatter(logging.Formatter):
     log_fout = None
     date_full = "[%(asctime)s %(lineno)d@%(filename)s:%(name)s] "
     date = "%(asctime)s "
@@ -71,7 +71,7 @@ class _MegEngineLogFormatter(logging.Formatter):
         if self.log_fout:
             self.__set_fmt(self.date_full + mtxt + self.msg)
-            formatted = super(_MegEngineLogFormatter, self).format(record)
+            formatted = super(MegEngineLogFormatter, self).format(record)
             nr_line = formatted.count("\n") + 1
             if nr_line >= self.max_lines:
                 head, body = formatted.split("\n", 1)
@@ -88,7 +88,7 @@ class _MegEngineLogFormatter(logging.Formatter):
             self.log_fout.flush()
 
         self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg))
-        formatted = super(_MegEngineLogFormatter, self).format(record)
+        formatted = super(MegEngineLogFormatter, self).format(record)
 
         if record.exc_text or record.exc_info:
             # handle exception format
@@ -125,7 +125,7 @@ class _MegEngineLogFormatter(logging.Formatter):
         self._style._fmt = fmt
 
-def get_logger(name=None, formatter=_MegEngineLogFormatter):
+def get_logger(name=None, formatter=MegEngineLogFormatter):
     r"""Gets megengine logger with given name."""
     logger = logging.getLogger(name)
@@ -167,16 +167,16 @@ try:
     from .core._imperative_rt.utils import Logger as _imperative_rt_logger
 
-    class _MegBrainLogFormatter(_MegEngineLogFormatter):
+    class MegBrainLogFormatter(MegEngineLogFormatter):
         date = "%(asctime)s[mgb] "
 
         def _color_date(self, msg):
             return "\x1b[33m{}\x1b[0m".format(msg)
 
-    _megbrain_logger = get_logger("megbrain", _MegBrainLogFormatter)
+    _megbrain_logger = get_logger("megbrain", MegBrainLogFormatter)
     _imperative_rt_logger.set_log_handler(_megbrain_logger)
 
-    def _set_mgb_log_level(level):
+    def set_mgb_log_level(level):
         r"""Sets megbrain log level
 
         Args:
@@ -200,30 +200,30 @@ try:
         )
         return rst
 
-    _set_mgb_log_level(_default_level)
+    set_mgb_log_level(_default_level)
 
 except ImportError as exc:
 
-    def _set_mgb_log_level(level):
+    def set_mgb_log_level(level):
         raise NotImplementedError("imperative_rt has not been imported")
 
 @contextlib.contextmanager
-def _replace_mgb_log_level(level):
+def replace_mgb_log_level(level):
     r"""Replaces megbrain log level in a block and restore after exiting.
 
     Args:
         level: new log level
     """
-    old = _set_mgb_log_level(level)
+    old = set_mgb_log_level(level)
     try:
         yield
     finally:
-        _set_mgb_log_level(old)
+        set_mgb_log_level(old)
 
 def enable_debug_log():
     r"""Sets logging level to debug for all components."""
     set_log_level(logging.DEBUG)
-    _set_mgb_log_level(logging.DEBUG)
+    set_mgb_log_level(logging.DEBUG)
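Hypothetical usage of the renamed context manager above, which swaps the megbrain log level and restores it on exit:

    import logging

    from megengine.logger import replace_mgb_log_level

    with replace_mgb_log_level(logging.ERROR):
        pass  # megbrain logs below ERROR are suppressed inside this block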
@@ -15,12 +15,12 @@ from . import init
 from .module import Module
 
-class _RNNCellBase(Module):
+class RNNCellBase(Module):
     def __init__(
         self, input_size: int, hidden_size: int, bias: bool, num_chunks: int,
     ) -> None:
         # num_chunks indicates the number of gates
-        super(_RNNCellBase, self).__init__()
+        super(RNNCellBase, self).__init__()
 
         self.input_size = input_size
         self.hidden_size = hidden_size
@@ -57,7 +57,7 @@ class _RNNCellBase(Module):
         raise NotImplementedError("forward not implemented !")
 
-class RNNCell(_RNNCellBase):
+class RNNCell(RNNCellBase):
     r"""An Elman RNN cell with tanh or ReLU non-linearity.
@@ -135,7 +135,7 @@ class RNNCell(_RNNCellBase):
         )[0]
 
-class LSTMCell(_RNNCellBase):
+class LSTMCell(RNNCellBase):
     r"""A long short-term memory (LSTM) cell.
@@ -216,7 +216,7 @@ class LSTMCell(_RNNCellBase):
         )[:2]
 
-class _RNNBase(Module):
+class RNNBase(Module):
     def __init__(
         self,
         input_size: int,
@@ -228,7 +228,7 @@ class _RNNBase(Module):
         bidirectional: bool = False,
         proj_size: int = 0,
     ) -> None:
-        super(_RNNBase, self).__init__()
+        super(RNNBase, self).__init__()
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.num_layers = num_layers
@@ -323,7 +323,7 @@ class _RNNBase(Module):
         return output, h
 
-class RNN(_RNNBase):
+class RNN(RNNBase):
     r"""Applies a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}` non-linearity to an
     input sequence.
@@ -453,7 +453,7 @@ class RNN(_RNNBase):
         return output, h
 
-class LSTM(_RNNBase):
+class LSTM(RNNBase):
     r"""Applies a multi-layer long short-term memory LSTM to an input
     sequence.
......
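A hypothetical single step through one of the cells whose shared base class is renamed above (shapes made up):

    import numpy as np

    import megengine
    import megengine.module as M

    cell = M.RNNCell(input_size=8, hidden_size=16)
    x = megengine.tensor(np.random.randn(4, 8).astype("float32"))
    h = megengine.tensor(np.zeros((4, 16), dtype="float32"))
    h_next = cell(x, h)  # one Elman RNN step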
@@ -7,7 +7,7 @@ import numpy as np
 from .. import functional as F
 from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
-from ..distributed import WORLD, is_distributed
+from ..distributed import WORLD, get_rank, is_distributed
 from ..functional.distributed import all_reduce_max, all_reduce_min
 from ..logger import get_logger
 from ..module import Module
......
@@ -66,7 +66,7 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=pickle.DEFAULT_PROTOCOL):
     pickle_module.dump(obj, f, pickle_protocol)
 
-class _dmap:
+class dmap:
     def __init__(self, map_location):
         self.map_location = map_location
@@ -177,5 +177,5 @@ def load(f, map_location=None, pickle_module=pickle):
     map_location = _get_callable_map_location(map_location)  # callable map_location
 
-    with _dmap(map_location) as dm:
+    with dmap(map_location) as dm:
         return pickle_module.load(f)
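A hypothetical round trip that exercises the renamed dmap context manager through load's map_location (file name made up):

    import megengine

    megengine.save({"step": 1}, "checkpoint.pkl")
    state = megengine.load("checkpoint.pkl", map_location="cpu0")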
@@ -33,7 +33,6 @@ def deprecated_func(version, origin, name, tbd):
             )
             return func(*args, **kwargs)
 
-        wrapper.__deprecated__ = True
         return wrapper
@@ -58,7 +57,6 @@ def deprecated_kwargs_default(version, kwargs_name, kwargs_pos):
             )
             return func(*args, **kwargs)
 
-        wrapper.__deprecated__ = True
         return wrapper
 
     return deprecated
@@ -11,11 +11,11 @@ from .. import functional as F
 from .. import get_logger
 from .. import module as M
 from ..core.tensor.dtype import get_dtype_bit
-from ..logger import _MegEngineLogFormatter
+from ..logger import MegEngineLogFormatter
 from .module_utils import set_module_mode_safe
 
 try:
-    _MegEngineLogFormatter.max_lines = float("inf")
+    MegEngineLogFormatter.max_lines = float("inf")
 except AttributeError as e:
     raise ValueError("set logger max lines failed")
......
@@ -2,14 +2,14 @@
 import logging
 
 from megengine.core._imperative_rt import Logger
-from megengine.logger import _imperative_rt_logger, _set_mgb_log_level
+from megengine.logger import _imperative_rt_logger, set_mgb_log_level
 
 def test_logger():
     orig_level = Logger().set_log_level(Logger.LogLevel.Debug)
     assert Logger().set_log_level(Logger.LogLevel.Debug) == Logger.LogLevel.Debug
     Logger().set_log_level(orig_level)
-    orig_level = _set_mgb_log_level(logging.DEBUG)
+    orig_level = set_mgb_log_level(logging.DEBUG)
     assert (
         _imperative_rt_logger.set_log_level(Logger.LogLevel.Debug)
         == Logger.LogLevel.Debug
......
@@ -50,7 +50,7 @@ def test_init_process_group(backend):
         assert mm_server_addr[0] == "localhost"
         assert mm_server_addr[1] > 0
-        assert isinstance(dist.get_client(), dist.server._Client)
+        assert isinstance(dist.get_client(), dist.Client)
 
     procs = []
     for rank in range(world_size):
......