Commit 026af620 authored by Megvii Engine Team

docs(mge): docs typo fix

GitOrigin-RevId: 851f6de02f308691ed4f27bd3801af022911b9d1
Parent 10106341
@@ -20,7 +20,7 @@ class GradManager:
the forward operations start and when all resources should be released. A typical usage of
GradManager is as follows:
.. code-block::

    gm = GradManager()
    gm.attach(model.parameters())
@@ -32,7 +32,7 @@ class GradManager:
You can also use the `record()` and `release()` methods instead of the `with` context:
.. code-block::

    gm = GradManager()
    gm.attach(model.parameters())
@@ -50,7 +50,7 @@ class GradManager:
processes. Users will finally get the averaged gradients if an "AllReduce"
callback is registered as follows:
.. code-block::

    import megengine.distributed as dist
@@ -71,7 +71,7 @@ class GradManager:
r"""Registers parameters that gradients should be calculated with respect to.
Callback functions should have a signature like this:
.. code-block::

    def cb(param: Tensor, grad: Tensor) -> Tensor:
        # do something
...
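To make the usage above concrete, here is a minimal training-step sketch; ``model``, ``loss_fn``, ``opt`` and the dataset loop are hypothetical placeholders, and the optimizer calls are assumed rather than taken from this commit.

.. code-block::

    gm = GradManager()
    gm.attach(model.parameters())

    for data, label in dataset:
        with gm:                      # same effect as gm.record() ... gm.release()
            pred = model(data)
            loss = loss_fn(pred, label)
            gm.backward(loss)         # gradients accumulate into the attached parameters
        opt.step()                    # hypothetical optimizer update
        opt.clear_grad()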
@@ -50,8 +50,8 @@ class Function:
"""
Applies operations to ``inputs`` and returns results. It must be overridden by all subclasses.
:param input: input tensors.
:return: a tuple of Tensor or a single Tensor.
.. note::
@@ -64,7 +64,7 @@ class Function:
"""
Compute the gradient of the forward function. It must be overridden by all subclasses.
:param output_grads: gradients of outputs that are returned by :meth:`~.function.Function.forward`.
.. note::
...
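As an illustration of the forward/backward contract described above, a minimal sketch of a Function subclass; the class name and the constant-multiplication example are hypothetical, not part of this commit.

.. code-block::

    class MulConstant(Function):
        """Hypothetical example: multiply the input by a fixed constant."""

        def __init__(self, constant):
            super().__init__()
            self.constant = constant

        def forward(self, inp):
            # may return a single Tensor or a tuple of Tensors
            return inp * self.constant

        def backward(self, output_grads):
            # gradient of `inp * self.constant` with respect to `inp`
            return output_grads * self.constant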
@@ -34,14 +34,14 @@ default_collate_err_msg_format = (
class Collator:
r"""
Used for merging a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a dataset.
Modified from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py
"""
def apply(self, inputs):
"""
:param inputs: sequence_N(tuple(CHW, C, CK)).
:return: tuple(NCHW, NC, NCK).
"""
elem = inputs[0]
elem_type = type(elem)
...
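A small sketch of the shape convention documented above, assuming the default Collator stacks numpy samples the same way PyTorch's default_collate does and that it is importable from ``megengine.data.collator``:

.. code-block::

    import numpy as np
    from megengine.data.collator import Collator

    samples = [
        (np.zeros((3, 32, 32), dtype="float32"), 0),   # (CHW image, label)
        (np.ones((3, 32, 32), dtype="float32"), 1),
    ]
    images, labels = Collator().apply(samples)
    # images has shape (2, 3, 32, 32); labels has shape (2,)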
@@ -43,7 +43,7 @@ class DataLoader:
):
r"""Provides a convenient way to iterate on a given dataset.
`DataLoader` combines a dataset with `sampler`, `transform` and `collator`,
making it flexible to get a minibatch continually from a dataset.
:type dataset: Dataset
@@ -53,21 +53,21 @@ class DataLoader:
If specified, :attr:`shuffle` must be ``False``.
:type transform: Transform
:param transform: defines the transforming strategy for a sampled batch.
Default: None
:type collator: Collator
:param collator: defines the merging strategy for a transformed batch.
Default: None
:type num_workers: int
:param num_workers: the number of sub-processes used to load, transform and collate
the batch. ``0`` means using single-process. Default: 0
:type timeout: int
:param timeout: if positive, the timeout value (in seconds) for collecting a
batch from workers. Default: 0
:type divide: bool
:param divide: defines the parallelization strategy in multi-processing mode.
``True`` means one batch is divided into :attr:`num_workers` pieces, and
the workers will process these pieces in parallel. ``False`` means
different sub-processes will process different batches. Default: False
"""
...
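Putting the parameters above together, a minimal sketch of building a loader; ``train_dataset`` is a hypothetical map-style dataset and the import paths are assumed from the usual ``megengine.data`` layout.

.. code-block::

    from megengine.data import DataLoader, RandomSampler
    from megengine.data.transform import ToMode

    sampler = RandomSampler(train_dataset, batch_size=64, drop_last=True)
    dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        transform=ToMode("CHW"),
        num_workers=4,            # 0 would mean single-process loading
    )
    for batch_image, batch_label in dataloader:
        pass                      # feed the minibatch to the model here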
@@ -12,7 +12,7 @@ from typing import Tuple
class Dataset(ABC):
r"""
An abstract class for all datasets.
"""
@abstractmethod
@@ -22,8 +22,8 @@ class Dataset(ABC):
class MapDataset(Dataset):
r"""
An abstract class for map data.
``__getitem__`` and ``__len__`` methods are additionally needed.
"""
@abstractmethod
@@ -41,8 +41,8 @@ class MapDataset(Dataset):
class StreamDataset(Dataset):
r"""
An abstract class for stream data.
``__iter__`` method is additionally needed.
"""
@abstractmethod
...
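A minimal sketch of a map-style dataset that satisfies the ``__getitem__``/``__len__`` contract above; the class and data are hypothetical.

.. code-block::

    import numpy as np

    class InMemoryDataset(MapDataset):
        """Hypothetical map-style dataset backed by in-memory arrays."""

        def __init__(self, images, labels):
            self.images = images
            self.labels = labels

        def __getitem__(self, index):
            return self.images[index], self.labels[index]

        def __len__(self):
            return len(self.images)

    dataset = InMemoryDataset(np.random.rand(100, 32, 32, 3), np.arange(100))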
@@ -21,7 +21,7 @@ logger = get_logger(__name__)
class CIFAR10(VisionDataset):
r""" ``Dataset`` for CIFAR10 meta data.
"""
url_path = "http://www.cs.utoronto.ca/~kriz/"
...
@@ -30,19 +30,18 @@ class ImageFolder(VisionDataset):
r"""
ImageFolder is a class for loading image data and labels from an organized folder.
The folder is expected to be organized as follows: root/cls/xxx.img_ext
Labels are indices of sorted classes in the root directory.
:param root: root directory of an image folder.
:param loader: a function used to load image from path,
if ``None``, default function that loads
images with PIL will be called.
:param check_valid_func: a function used to check if files in folder are
expected image files, if ``None``, default function
that checks file extensions will be called.
:param class_name: if ``True``, return class name instead of class index.
"""
super().__init__(root, order=("image", "image_category"))
...
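For example, with a directory laid out as root/cls/xxx.img_ext, usage could look like the sketch below; the import path and the example path are assumptions.

.. code-block::

    from megengine.data.dataset import ImageFolder

    dataset = ImageFolder("/path/to/root", class_name=False)
    image, label = dataset[0]    # label is the index of the sorted class directory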
@@ -31,7 +31,7 @@ logger = get_logger(__name__)
class ImageNet(ImageFolder):
r"""
Load ImageNet from raw files or folder. Expected folder looks like:
.. code-block:: bash
@@ -60,25 +60,25 @@ class ImageNet(ImageFolder):
def __init__(self, root: str = None, train: bool = True, **kwargs):
r"""
Initialization:
* if ``root`` contains ``self.target_folder`` depending on ``train``:
* initialize ImageFolder with target_folder.
* else:
* if all raw files are in ``root``:
* parse ``self.target_folder`` from raw files.
* initialize ImageFolder with ``self.target_folder``.
* else:
* raise error.
:param root: root directory of imagenet data, if root is ``None``, use default_dataset_root.
:param train: if ``True``, load the train split, otherwise load the validation split.
"""
# process the root path
...
@@ -22,12 +22,12 @@ logger = get_logger(__name__)
class MNIST(VisionDataset):
r""" ``Dataset`` for MNIST meta data.
"""
url_path = "http://yann.lecun.com/exdb/mnist/"
"""
URL prefix for downloading raw files.
"""
raw_file_name = [
"train-images-idx3-ubyte.gz",
@@ -36,7 +36,7 @@ class MNIST(VisionDataset):
"t10k-labels-idx1-ubyte.gz",
]
"""
Raw file names of both training set and test set (10k).
"""
raw_file_md5 = [
"f68b3c2dcbeaaa9fbdd348bbdeb94873",
@@ -45,7 +45,7 @@ class MNIST(VisionDataset):
"ec29112dd5afa0611ce80d1b7f02629c",
]
"""
MD5 checksums for checking raw files.
"""
def __init__(
@@ -57,10 +57,10 @@ class MNIST(VisionDataset):
):
r"""
:param root: path for mnist dataset downloading or loading, if ``None``,
set ``root`` to the ``_default_root``.
:param train: if ``True``, load the training set, else load the test set.
:param download: if raw files do not exist and ``download`` is set to ``True``,
download and process raw files, otherwise raise ValueError. Default: True
"""
super().__init__(root, order=("image", "image_category"))
...
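A short usage sketch based on the parameters documented above; the import path is assumed from the usual ``megengine.data.dataset`` layout.

.. code-block::

    from megengine.data.dataset import MNIST

    train_set = MNIST(root=None, train=True, download=True)   # downloads to the default root
    image, label = train_set[0]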
@@ -28,25 +28,25 @@ class Sampler(ABC):
seed=None,
):
r"""
An abstract class for all samplers.
:type dataset: `dataset`
:param dataset: dataset to sample from.
:type batch_size: positive integer
:param batch_size: batch size for the batch method.
:type drop_last: bool
:param drop_last: set ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch_size, then the last batch will
be smaller. Default: False
:type num_samples: positive integer
:param num_samples: number of samples assigned to one rank.
:type world_size: positive integer
:param world_size: number of ranks.
:type rank: non-negative integer within 0 and world_size
:param rank: rank id, non-negative integer within 0 and ``world_size``.
:type seed: non-negative integer
:param seed: seed for random operators.
"""
if (
not isinstance(batch_size, int)
@@ -103,15 +103,15 @@ class Sampler(ABC):
def sample(self):
"""
Return a list containing all sample indices.
"""
raise NotImplementedError
def scatter(self, indices) -> List:
r"""
Scatter method is used for splitting indices into subsets; each subset
will be assigned to a rank. Indices are split evenly by default.
If a customized indices assignment method is needed, please rewrite this method.
"""
total_size = self.num_samples * self.world_size
@@ -127,7 +127,7 @@ class Sampler(ABC):
def batch(self) -> Iterator[List[Any]]:
r"""
Batch method provides a batch indices generator.
"""
indices = list(self.sample())
@@ -156,7 +156,7 @@ class SequentialSampler(Sampler):
rank=None,
):
r"""
Sample elements sequentially.
"""
super().__init__(dataset, batch_size, drop_last, None, world_size, rank)
if indices is not None and not isinstance(indices, collections.abc.Sequence):
@@ -168,7 +168,7 @@ class SequentialSampler(Sampler):
def sample(self) -> Iterator[Any]:
r"""
Return a generator.
"""
if self.indices is None:
return iter(range(len(self.dataset)))
@@ -188,7 +188,7 @@ class RandomSampler(Sampler):
seed=None,
):
r"""
Sample elements randomly without replacement.
"""
super().__init__(dataset, batch_size, drop_last, None, world_size, rank, seed)
if indices is not None and not isinstance(indices, collections.abc.Sequence):
@@ -218,10 +218,10 @@ class ReplacementSampler(Sampler):
seed=None,
):
r"""
Sample elements randomly with replacement.
:type weights: List
:param weights: weights for sampling indices, it could be unnormalized weights.
"""
super().__init__(
dataset, batch_size, drop_last, num_samples, world_size, rank, seed
@@ -250,7 +250,7 @@ class ReplacementSampler(Sampler):
class Infinite(Sampler):
r"""Infinite Sampler wrapper for a basic sampler."""
def sample(self):
raise NotImplementedError("sample method not supported in Infinite")
...
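To illustrate the ``sample``/``batch`` interface described above, a small sketch; ``dataset`` is a hypothetical map-style dataset and the import path is assumed.

.. code-block::

    from megengine.data import RandomSampler, SequentialSampler

    seq_sampler = SequentialSampler(dataset, batch_size=4, drop_last=False)
    rnd_sampler = RandomSampler(dataset, batch_size=4, drop_last=False, seed=0)

    for batch_indices in seq_sampler.batch():
        # each element is a list of at most `batch_size` dataset indices
        print(batch_indices)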
@@ -12,7 +12,7 @@ from typing import Sequence, Tuple
class Transform(ABC):
"""
Rewrite the ``apply`` method in subclasses.
"""
def apply_batch(self, inputs: Sequence[Tuple]):
...
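A minimal sketch of such a subclass (hypothetical example, not part of this commit):

.. code-block::

    class AddOffset(Transform):
        """Hypothetical transform that adds a constant offset to each sample."""

        def __init__(self, offset=1):
            self.offset = offset

        def apply(self, input):
            return input + self.offset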
@@ -15,7 +15,7 @@ import numpy as np
def wrap_keepdims(func):
"""Wrapper to keep the dimension of input images unchanged."""
@functools.wraps(func)
def wrapper(image, *args, **kwargs):
@@ -34,10 +34,10 @@ def wrap_keepdims(func):
@wrap_keepdims
def to_gray(image):
r"""
Change BGR format image's color space to gray.
:param image: input BGR format image, with `(H, W, C)` shape.
:return: gray format image, with `(H, W, C)` shape.
"""
return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
@@ -45,10 +45,10 @@ def to_gray(image):
@wrap_keepdims
def to_bgr(image):
r"""
Change gray format image's color space to BGR.
:param image: input gray format image, with `(H, W, C)` shape.
:return: BGR format image, with `(H, W, C)` shape.
"""
return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
@@ -56,18 +56,18 @@ def to_bgr(image):
@wrap_keepdims
def pad(input, size, value):
r"""
Pad input data with *value* and given *size*.
:param input: input data, with `(H, W, C)` shape.
:param size: padding size of input data, it could be an integer or a sequence.
If it is an integer, the input data will be padded in four directions.
If it is a sequence containing two integers, the bottom and right sides
of input data will be padded.
If it is a sequence containing four integers, the top, bottom, left and right
sides of input data will be padded with the given size.
:param value: padding value of data, could be a sequence of int or float.
If it is a float value, the dtype of the image will also be cast to float32.
:return: padded image.
"""
if isinstance(size, int):
size = (size, size, size, size)
@@ -81,14 +81,18 @@ def pad(input, size, value):
@wrap_keepdims
def flip(image, flipCode):
r"""
According to the flipCode (the type of flip), flip the input image.
:param image: input image, with `(H, W, C)` shape.
:param flipCode: code that indicates the type of flip.

* 1 : Flip horizontally
* 0 : Flip vertically
* -1: Flip horizontally and vertically

:return: flipped image, with `(H, W, C)` shape.
"""
return cv2.flip(image, flipCode=flipCode)
@@ -96,12 +100,12 @@ def flip(image, flipCode):
@wrap_keepdims
def resize(input, size, interpolation=cv2.INTER_LINEAR):
r"""
Resize the input data to the given size.
:param input: input data, could be image or masks, with `(H, W, C)` shape.
:param size: target size of input data, with (height, width) shape.
:param interpolation: interpolation method.
:return: resized data, with `(H, W, C)` shape.
"""
if len(size) != 2:
raise ValueError("resize needs (h, w), but got {}".format(size))
...
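A quick sketch of the helpers above on a random BGR image; the import path is assumed, and the shapes follow the docstrings (``wrap_keepdims`` keeps the channel dimension).

.. code-block::

    import numpy as np
    from megengine.data.transform.functional import to_gray, pad, flip, resize

    image = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)

    gray = to_gray(image)                      # (224, 224, 1)
    padded = pad(image, size=(2, 2), value=0)  # pad the bottom and right sides by 2 pixels
    flipped = flip(image, flipCode=1)          # 1: horizontal flip
    small = resize(image, (112, 112))          # target size given as (height, width)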
@@ -44,26 +44,26 @@ __all__ = [
class VisionTransform(Transform):
r"""
Base class of all transforms used in computer vision.
Calling logic: apply_batch() -> apply() -> _apply_image() and other _apply_*()
methods. If you want to implement a self-defined transform method for images,
rewrite the _apply_image method in your subclass.
:param order: input type order. Input is a tuple containing different structures,
order is used to specify the order of structures. For example, if your input
is (image, boxes) type, then the ``order`` should be ("image", "boxes").
Currently available strings and data types are described below:
* "image": input image, with shape of `(H, W, C)`.
* "coords": coordinates, with shape of `(N, 2)`.
* "boxes": bounding boxes, with shape of `(N, 4)`, "xyxy" format,
the 1st "xy" represents the top left point of a box,
the 2nd "xy" represents the right bottom point.
* "mask": map used for segmentation, with shape of `(H, W, 1)`.
* "keypoints": keypoints with shape of `(N, K, 3)`, N for number of instances,
and K for number of keypoints in one instance. The first two dimensions
of the last axis are the coordinates of keypoints and the 3rd dimension is
the label of keypoints.
* "polygons": a sequence containing numpy arrays, its length is the number of instances.
Each numpy array represents the polygon coordinates of one instance.
* "category": categories for some data type. For example, "image_category"
means category of the input image and "boxes_category" means categories of
@@ -94,11 +94,11 @@ class VisionTransform(Transform):
self.order = order
def apply_batch(self, inputs: Sequence[Tuple]):
r"""Apply transform on batch input data."""
return tuple(self.apply(input) for input in inputs)
def apply(self, input: Tuple):
r"""Apply transform on single input data."""
if not isinstance(input, tuple):
input = (input,)
@@ -156,10 +156,10 @@ class VisionTransform(Transform):
class ToMode(VisionTransform):
r"""Change input data to a target mode.
For example, most transforms use HWC mode images,
while the neural network might use CHW mode input tensors.
:param mode: output mode of input. Default: "CHW"
:param order: the same as :class:`VisionTransform`.
"""
def __init__(self, mode="CHW", *, order=None):
@@ -185,14 +185,14 @@ class Compose(VisionTransform):
r"""
Composes several transforms together.
:param transforms: list of :class:`VisionTransform` to compose.
:param batch_compose: whether to use shuffle_indices for batch data or not.
If True, use the original input sequence.
Otherwise, the shuffle_indices will be used for transforms.
:param shuffle_indices: indices used for random shuffle, starting at 1.
For example, if shuffle_indices is [(1, 3), (2, 4)], then the 1st and 3rd transforms
will be randomly shuffled, and the 2nd and 4th transforms will also be shuffled.
:param order: the same as :class:`VisionTransform`.
Examples:
@@ -264,8 +264,8 @@ class TorchTransformCompose(VisionTransform):
some transforms with tensor in torchvision are not supported,
such as Normalize and ToTensor in torchvision.
:param transforms: the same as ``Compose``.
:param order: the same as :class:`VisionTransform`.
"""
def __init__(self, transforms, *, order=None):
@@ -303,16 +303,16 @@ class TorchTransformCompose(VisionTransform):
class Pad(VisionTransform):
r"""Pad the input data.
:param size: padding size of input image, it could be an integer or a sequence.
If it is an integer, the input image will be padded in four directions.
If it is a sequence containing two integers, the bottom and right sides
of the image will be padded.
If it is a sequence containing four integers, the top, bottom, left and right
sides of the image will be padded with the given size.
:param value: padding value of image, could be a sequence of int or float.
If it is a float value, the dtype of the image will also be cast to float32.
:param mask_value: padding value of segmentation map.
:param order: the same as :class:`VisionTransform`.
"""
def __init__(self, size=0, value=0, mask_value=0, *, order=None):
@@ -350,15 +350,15 @@ class Pad(VisionTransform):
class Resize(VisionTransform):
r"""Resize the input data.
:param output_size: target size of image, with (height, width) shape.
:param interpolation: interpolation method. All methods are listed below:
* cv2.INTER_NEAREST – a nearest-neighbor interpolation.
* cv2.INTER_LINEAR – a bilinear interpolation (used by default).
* cv2.INTER_AREA – resampling using pixel area relation.
* cv2.INTER_CUBIC – a bicubic interpolation over 4×4 pixel neighborhood.
* cv2.INTER_LANCZOS4 – a Lanczos interpolation over 8×8 pixel neighborhood.
:param order: the same as :class:`VisionTransform`.
"""
def __init__(self, output_size, interpolation=cv2.INTER_LINEAR, *, order=None):
@@ -476,8 +476,8 @@ class ShortestEdgeResize(VisionTransform):
class RandomResize(VisionTransform):
r"""Resize the input data randomly.
:param scale_range: range of scaling.
:param order: the same as :class:`VisionTransform`.
"""
def __init__(self, scale_range, interpolation=cv2.INTER_LINEAR, *, order=None):
@@ -519,13 +519,13 @@ class RandomResize(VisionTransform):
class RandomCrop(VisionTransform):
r"""Crop the input data randomly. Before applying the crop transform,
pad the image first. If the target size is still bigger than the size of the
padded image, pad the image up to the target size.
:param output_size: target size of output image, with (height, width) shape.
:param padding_size: the same as `size` in ``Pad``.
:param padding_value: the same as `value` in ``Pad``.
:param order: the same as :class:`VisionTransform`.
"""
def __init__(
@@ -580,10 +580,10 @@ class RandomResizedCrop(VisionTransform):
aspect ratio (default: 3/4 to 1.33) of the original aspect ratio is made.
After applying the crop transform, the input data will be resized to the given size.
:param output_size: target size of output image, with (height, width) shape.
:param scale_range: range of size of the origin size cropped. Default: (0.08, 1.0)
:param ratio_range: range of aspect ratio of the origin aspect ratio cropped. Default: (0.75, 1.33)
:param order: the same as :class:`VisionTransform`.
"""
def __init__(
@@ -666,8 +666,8 @@ class RandomResizedCrop(VisionTransform):
class CenterCrop(VisionTransform):
r"""Crops the given input data at the center.
:param output_size: target size of output image, with (height, width) shape.
:param order: the same as :class:`VisionTransform`.
"""
def __init__(self, output_size, *, order=None):
@@ -710,7 +710,7 @@ class RandomHorizontalFlip(VisionTransform):
r"""Horizontally flip the input data randomly with a given probability.
:param prob: probability of the input data being flipped. Default: 0.5
:param order: the same as :class:`VisionTransform`.
"""
def __init__(self, prob: float = 0.5, *, order=None):
@@ -742,7 +742,7 @@ class RandomVerticalFlip(VisionTransform):
r"""Vertically flip the input data randomly with a given probability.
:param prob: probability of the input data being flipped. Default: 0.5
:param order: the same as :class:`VisionTransform`.
"""
def __init__(self, prob: float = 0.5, *, order=None):
@@ -776,9 +776,9 @@ class Normalize(VisionTransform):
this transform will normalize each channel of the input data.
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
:param mean: sequence of means for each channel.
:param std: sequence of standard deviations for each channel.
:param order: the same as :class:`VisionTransform`.
"""
def __init__(self, mean=0.0, std=1.0, *, order=None):
@@ -802,7 +802,7 @@ class GaussianNoise(VisionTransform):
:param mean: Gaussian mean used to generate noise.
:param std: Gaussian standard deviation used to generate noise.
:param order: the same as :class:`VisionTransform`.
"""
def __init__(self, mean=0.0, std=1.0, *, order=None):
@@ -826,9 +826,9 @@ class GaussianNoise(VisionTransform):
class BrightnessTransform(VisionTransform):
r"""Adjust brightness of the input data.
:param value: how much to adjust the brightness. Can be any
non-negative number. 0 gives the original image.
:param order: the same as :class:`VisionTransform`.
"""
def __init__(self, value, *, order=None):
@@ -857,9 +857,9 @@ class BrightnessTransform(VisionTransform):
class ContrastTransform(VisionTransform):
r"""Adjust contrast of the input data.
:param value: how much to adjust the contrast. Can be any
non-negative number. 0 gives the original image.
:param order: the same as :class:`VisionTransform`.
"""
def __init__(self, value, *, order=None):
@@ -888,9 +888,9 @@ class ContrastTransform(VisionTransform):
class SaturationTransform(VisionTransform):
r"""Adjust saturation of the input data.
:param value: how much to adjust the saturation. Can be any
non-negative number. 0 gives the original image.
:param order: the same as :class:`VisionTransform`.
"""
def __init__(self, value, *, order=None):
@@ -919,9 +919,9 @@ class SaturationTransform(VisionTransform):
class HueTransform(VisionTransform):
r"""Adjust hue of the input data.
:param value: how much to adjust the hue. Can be any number
between 0 and 0.5, 0 gives the original image.
:param order: the same as :class:`VisionTransform`.
"""
def __init__(self, value, *, order=None):
@@ -957,19 +957,19 @@ class HueTransform(VisionTransform):
class ColorJitter(VisionTransform):
r"""Randomly change the brightness, contrast, saturation and hue of an image.
:param brightness: how much to jitter brightness.
Chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
or the given [min, max]. Should be non-negative numbers.
:param contrast: how much to jitter contrast.
Chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
or the given [min, max]. Should be non-negative numbers.
:param saturation: how much to jitter saturation.
Chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
or the given [min, max]. Should be non-negative numbers.
:param hue: how much to jitter hue.
Chosen uniformly from [-hue, hue] or the given [min, max].
Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
:param order: the same as :class:`VisionTransform`.
"""
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, *, order=None):
...
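A typical pipeline built from the transforms above might look like the sketch below; the import path and the concrete numbers are assumptions, not part of this commit.

.. code-block::

    from megengine.data.transform import (
        Compose, Resize, RandomHorizontalFlip, Normalize, ToMode
    )

    transform = Compose([
        Resize((224, 224)),
        RandomHorizontalFlip(0.5),
        Normalize(mean=128.0, std=64.0),   # arbitrary example statistics
        ToMode("CHW"),
    ], order=["image", "image_category"])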
@@ -71,11 +71,11 @@ def set_default_device(device: str = "xpux"):
'multithread' device type is available for inference, which implements
multi-threading parallelism at the operator level. For example,
'multithread4' will compute with 4 threads.
The default value is 'xpux' to specify any device available. The priority of using gpu is higher when both gpu and cpu are available.
It can also be set by the environment variable `MGE_DEFAULT_DEVICE`.
"""
assert _valid_device(device), "Invalid device name {}".format(device)
CompNode._set_default_device(device)
@@ -99,13 +99,13 @@ def set_prealloc_config(
growth_factor=2.0,
device_type=DeviceType.CUDA,
):
"""Specifies how to pre-allocate from the raw device allocator.
:param alignment: specifies the alignment in bytes.
:param min_req: min request size in bytes.
:param max_overhead: max overhead above required size in bytes.
:param growth_factor: `request size / cur allocated`
:param device_type: the device type.
"""
assert alignment > 0
...
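For reference, a sketch of how these settings might be used; it assumes both helpers live in ``megengine.device``, which is not confirmed by this commit.

.. code-block::

    from megengine.device import set_default_device, set_prealloc_config

    set_default_device("cpu0")     # or "gpu0", "xpux", "multithread4", ...

    # pre-allocation tuning for the CUDA allocator (parameter names as documented above)
    set_prealloc_config(alignment=512, min_req=32 * 1024 * 1024, growth_factor=2.0)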
@@ -102,7 +102,7 @@ def _(op: RemoteRecv):
def collective_comm(inp, mode, group, device):
"""Helper function for applying collective communication functions."""
assert isinstance(group, Group)
if group is None:
return inp
@@ -123,11 +123,11 @@ def collective_comm(inp, mode, group, device):
def reduce_sum(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create reduce_sum operator for collective communication.
:param inp: input tensor.
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveCommMode.REDUCE_SUM
return collective_comm(inp, mode, group, device)
@@ -136,11 +136,11 @@ def reduce_sum(
def broadcast(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create broadcast operator for collective communication.
:param inp: input tensor.
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveCommMode.BROADCAST
return collective_comm(inp, mode, group, device)
@@ -149,11 +149,11 @@ def broadcast(
def all_gather(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create all_gather operator for collective communication.
:param inp: input tensor.
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveCommMode.ALL_GATHER
return collective_comm(inp, mode, group, device)
@@ -162,11 +162,11 @@ def all_gather(
def reduce_scatter_sum(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create reduce_scatter_sum operator for collective communication.
:param inp: input tensor.
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveCommMode.REDUCE_SCATTER_SUM
return collective_comm(inp, mode, group, device)
@@ -175,11 +175,11 @@ def reduce_scatter_sum(
def all_reduce_sum(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create all_reduce_sum operator for collective communication.
:param inp: input tensor.
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveCommMode.ALL_REDUCE_SUM
return collective_comm(inp, mode, group, device)
@@ -188,11 +188,11 @@ def all_reduce_sum(
def all_reduce_max(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create all_reduce_max operator for collective communication.
:param inp: input tensor.
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveCommMode.ALL_REDUCE_MAX
return collective_comm(inp, mode, group, device)
@@ -201,11 +201,11 @@ def all_reduce_max(
def all_reduce_min(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create all_reduce_min operator for collective communication.
:param inp: input tensor.
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveCommMode.ALL_REDUCE_MIN
return collective_comm(inp, mode, group, device)
@@ -214,11 +214,11 @@ def all_reduce_min(
def gather(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create gather operator for collective communication.
:param inp: input tensor.
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveCommMode.GATHER
return collective_comm(inp, mode, group, device)
@@ -227,11 +227,11 @@ def gather(
def scatter(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create scatter operator for collective communication.
:param inp: input tensor.
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveCommMode.SCATTER
return collective_comm(inp, mode, group, device)
@@ -240,21 +240,21 @@ def scatter(
def all_to_all(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create all_to_all operator for collective communication.
:param inp: input tensor.
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveCommMode.ALL_TO_ALL
return collective_comm(inp, mode, group, device)
def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
"""Send a Tensor to a remote process.
:param inp: tensor to send.
:param dest_rank: destination process rank.
"""
op = RemoteSend()
op.key = "{}->{}".format(get_rank(), dest_rank)
@@ -266,12 +266,12 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
def remote_recv(
src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
) -> Tensor:
"""Receive a Tensor from a remote process.
:param src_rank: source process rank.
:param shape: the shape of the tensor to receive.
:param dtype: the data type of the tensor to receive.
:param device: the device to place the received tensor.
"""
key = "{}->{}".format(src_rank, get_rank())
...
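A sketch of the communication helpers above inside an already-initialized worker process; the module path ``megengine.distributed.functional`` and the tensor construction are assumptions.

.. code-block::

    import numpy as np
    import megengine as mge
    import megengine.distributed as dist
    from megengine.distributed.functional import all_reduce_sum, remote_send, remote_recv

    x = mge.tensor([float(dist.get_rank())])
    total = all_reduce_sum(x)                  # sum of x over all ranks in the WORLD group

    if dist.get_rank() == 0:
        remote_send(x, dest_rank=1)
    elif dist.get_rank() == 1:
        y = remote_recv(0, shape=(1,), dtype=np.float32)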
...@@ -83,12 +83,12 @@ def init_process_group( ...@@ -83,12 +83,12 @@ def init_process_group(
) -> None: ) -> None:
"""Initialize the distributed process group and specify the device used in the current process """Initialize the distributed process group and specify the device used in the current process
:param master_ip: IP address of the master node :param master_ip: ip address of the master node.
:param port: Port available for all processes to communicate :param port: port available for all processes to communicate.
:param world_size: Total number of processes participating in the job :param world_size: total number of processes participating in the job.
:param rank: Rank of the current process :param rank: rank of the current process.
:param device: The GPU device id to bind this process to :param device: the GPU device id to bind this process to.
:param backend: Communicator backend, currently support 'nccl' and 'ucx' :param backend: communicator backend, currently support 'nccl' and 'ucx'.
""" """
if not isinstance(master_ip, str): if not isinstance(master_ip, str):
raise TypeError("Expect type str but got {}".format(type(master_ip))) raise TypeError("Expect type str but got {}".format(type(master_ip)))
...@@ -127,50 +127,50 @@ def init_process_group( ...@@ -127,50 +127,50 @@ def init_process_group(
def is_distributed() -> bool: def is_distributed() -> bool:
"""Return True if the distributed process group has been initialized""" """Return True if the distributed process group has been initialized."""
return _sd is not None return _sd is not None
def get_rank() -> int: def get_rank() -> int:
"""Get the rank of the current process""" """Get the rank of the current process."""
return _sd.proc_rank if _sd is not None else 0 return _sd.proc_rank if _sd is not None else 0
def get_world_size() -> int: def get_world_size() -> int:
"""Get the total number of processes participating in the job""" """Get the total number of processes participating in the job."""
return _sd.world_size if _sd is not None else 1 return _sd.world_size if _sd is not None else 1
def get_backend() -> str: def get_backend() -> str:
"""Get the backend str""" """Get the backend str."""
assert _sd is not None, "please call init_process_group first" assert _sd is not None, "please call init_process_group first"
return _sd.backend if _sd is not None else None return _sd.backend if _sd is not None else None
def get_py_server_addr() -> Tuple[str, int]: def get_py_server_addr() -> Tuple[str, int]:
"""Get master_ip and port of python XML RPC server""" """Get master_ip and port of python XML RPC server."""
assert _sd is not None, "please call init_process_group first" assert _sd is not None, "please call init_process_group first"
return _sd.master_ip, _sd.py_server_port return _sd.master_ip, _sd.py_server_port
def get_mm_server_addr() -> Tuple[str, int]: def get_mm_server_addr() -> Tuple[str, int]:
"""Get master_ip and port of C++ mm_server""" """Get master_ip and port of C++ mm_server."""
assert _sd is not None, "please call init_process_group first" assert _sd is not None, "please call init_process_group first"
return _sd.master_ip, _sd.mm_server_port return _sd.master_ip, _sd.mm_server_port
def get_client() -> Client: def get_client() -> Client:
"""Get client of python XML RPC server""" """Get client of python XML RPC server."""
assert _sd is not None, "please call init_process_group first" assert _sd is not None, "please call init_process_group first"
return _sd.client return _sd.client
def new_group(proc_ranks: List[int]) -> Group: def new_group(proc_ranks: List[int]) -> Group:
"""Build a subgroup containing certain ranks""" """Build a subgroup containing certain ranks."""
return Group(proc_ranks) return Group(proc_ranks)
def group_barrier(group: Optional[Group] = WORLD) -> None: def group_barrier(group: Optional[Group] = WORLD) -> None:
"""Block until all ranks in the group reach this barrier""" """Block until all ranks in the group reach this barrier."""
assert isinstance(group, Group) assert isinstance(group, Group)
_sd.client.group_barrier(group.key, group.size) _sd.client.group_barrier(group.key, group.size)
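As an illustration of the query helpers and the barrier above, a hedged sketch of a "log only on rank 0, then synchronize" pattern; it assumes the process group has already been initialized:

.. code-block::

    import megengine.distributed as dist

    def log_once(msg):
        # only the master process prints; every rank then meets at the barrier
        if not dist.is_distributed() or dist.get_rank() == 0:
            print("[world_size={}] {}".format(dist.get_world_size(), msg))
        if dist.is_distributed():
            dist.group_barrier()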
...@@ -15,7 +15,7 @@ from .util import get_free_ports ...@@ -15,7 +15,7 @@ from .util import get_free_ports
def _run_wrapped(func, master_ip, port, world_size, rank, dev, args, kwargs): def _run_wrapped(func, master_ip, port, world_size, rank, dev, args, kwargs):
"""init distributed process group and run wrapped function""" """Init distributed process group and run wrapped function."""
init_process_group( init_process_group(
master_ip=master_ip, port=port, world_size=world_size, rank=rank, device=dev master_ip=master_ip, port=port, world_size=world_size, rank=rank, device=dev
) )
...@@ -23,7 +23,7 @@ def _run_wrapped(func, master_ip, port, world_size, rank, dev, args, kwargs): ...@@ -23,7 +23,7 @@ def _run_wrapped(func, master_ip, port, world_size, rank, dev, args, kwargs):
def launcher(func): def launcher(func):
"""decorator for launching multiple processes in single-machine multi-gpu training""" """Decorator for launching multiple processes in single-machine multi-gpu training."""
n_gpus = get_device_count_by_fork("gpu") n_gpus = get_device_count_by_fork("gpu")
......
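A usage sketch of the decorator, inferred from ``_run_wrapped`` above: each spawned process gets its own rank and device, and the process group is initialized before the wrapped function runs. The argument forwarding and the ``dist.launcher`` import path are assumptions; ``"/data/train"`` is a placeholder.

.. code-block::

    import megengine.distributed as dist

    @dist.launcher
    def worker(data_path):
        # one process per visible GPU; get_rank()/get_world_size() are valid here
        print("rank {} of {} training on {}".format(
            dist.get_rank(), dist.get_world_size(), data_path))

    worker("/data/train")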
...@@ -26,14 +26,14 @@ def set_conv_execution_strategy(option: str): ...@@ -26,14 +26,14 @@ def set_conv_execution_strategy(option: str):
Available values: Available values:
* 'HEURISTIC' uses heuristic to choose the fastest algorithm. * 'HEURISTIC' uses heuristic to choose the fastest algorithm.
* 'PROFILE' runs possible algorithms on real device to find the best. * 'PROFILE' runs possible algorithms on real device to find the best one.
* 'PROFILE_HEURISTIC' uses profile result and heuristic to choose the fastest algorithm. * 'PROFILE_HEURISTIC' uses profiling result and heuristic to choose the fastest algorithm.
* 'PROFILE_REPRODUCIBLE' uses the fastest of profile result that is also reproducible. * 'PROFILE_REPRODUCIBLE' uses the fastest of the profiling results that is also reproducible.
* 'HEURISTIC_REPRODUCIBLE' uses heuristic to choose the fastest algorithm that is also reproducible. * 'HEURISTIC_REPRODUCIBLE' uses heuristic to choose the fastest algorithm that is also reproducible.
The default strategy is 'HEURISTIC'. The default strategy is 'HEURISTIC'.
It can also be set through the environmental variable 'MEGENGINE_CONV_EXECUTION_STRATEGY'. It can also be set through the environment variable 'MEGENGINE_CONV_EXECUTION_STRATEGY'.
""" """
valid_option = ( valid_option = (
"HEURISTIC", "HEURISTIC",
......
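Both ways of selecting a strategy, sketched below; the environment-variable route matches the docstring, while the exact import path of the setter is an assumption:

.. code-block::

    import os

    # option 1: environment variable, read when the process starts
    os.environ["MEGENGINE_CONV_EXECUTION_STRATEGY"] = "PROFILE"

    # option 2: call the setter directly (import path assumed)
    from megengine.functional.debug_param import set_conv_execution_strategy
    set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE")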
...@@ -99,8 +99,9 @@ def _elemwise_multi_type(*args, mode, **kwargs): ...@@ -99,8 +99,9 @@ def _elemwise_multi_type(*args, mode, **kwargs):
def add(x, y): def add(x, y):
"""Element-wise addition. """Element-wise `addition`.
At least one operand should be tensor. At least one operand should be a tensor.
Same for sub/mul/div/floor_div/pow/mod/atan2/eq/ne/lt/le/gt/ge/maximum/minmium. Same for sub/mul/div/floor_div/pow/mod/atan2/eq/ne/lt/le/gt/ge/maximum/minimum.
:param x: input tensor. :param x: input tensor.
...@@ -131,68 +132,68 @@ def add(x, y): ...@@ -131,68 +132,68 @@ def add(x, y):
def sub(x, y): def sub(x, y):
"""Element-wise subtraction.""" """Element-wise `subtraction`."""
return _elwise(x, y, mode="sub") return _elwise(x, y, mode="sub")
def mul(x, y): def mul(x, y):
"""Element-wise multiplication.""" """Element-wise `multiplication`."""
return _elwise(x, y, mode="mul") return _elwise(x, y, mode="mul")
def div(x, y): def div(x, y):
"""Element-wise (x / y).""" """Element-wise `(x / y)`."""
return _elwise(x, y, mode="true_div") return _elwise(x, y, mode="true_div")
def floor_div(x, y): def floor_div(x, y):
"""Element-wise floor(x / y).""" """Element-wise `floor(x / y)`."""
return _elwise(x, y, mode="floor_divide") return _elwise(x, y, mode="floor_divide")
def neg(x): def neg(x):
"""Element-wise negation.""" """Element-wise `negation`."""
return _elwise(x, mode="negate") return _elwise(x, mode="negate")
def pow(x, y): def pow(x, y):
"""Element-wise power.""" """Element-wise `power`."""
return _elwise(x, y, mode="pow") return _elwise(x, y, mode="pow")
def mod(x, y): def mod(x, y):
"""Element-wise remainder of division.""" """Element-wise `remainder of division`."""
return _elwise(x, y, mode="mod") return _elwise(x, y, mode="mod")
def abs(x): def abs(x):
"""Element-wise absolute value.""" """Element-wise `absolute value`."""
return _elwise(x, mode="abs") return _elwise(x, mode="abs")
def exp(x): def exp(x):
"""Element-wise exponential.""" """Element-wise `exponential`."""
return _elwise(x, mode="exp") return _elwise(x, mode="exp")
def expm1(x): def expm1(x):
"""Element-wise exp(x)-1.""" """Element-wise `exp(x)-1`."""
return _elwise(x, mode="expm1") return _elwise(x, mode="expm1")
def log(x): def log(x):
"""Element-wise logarithm (base `e`).""" """Element-wise `logarithm (base e)`."""
return _elwise(x, mode="log") return _elwise(x, mode="log")
def log1p(x): def log1p(x):
"""Element-wise log(x+1) (base `e`).""" """Element-wise `log(x+1) (base e)`."""
return _elwise(x, mode="log1p") return _elwise(x, mode="log1p")
def sqrt(x: Tensor) -> Tensor: def sqrt(x: Tensor) -> Tensor:
"""Element-wise sqrt. """Element-wise `sqrt`.
For negative input value, return ``NaN``. Returns ``NaN`` for negative input value.
:param x: input tensor. :param x: input tensor.
:return: computed tensor. :return: computed tensor.
...@@ -222,7 +223,7 @@ def sqrt(x: Tensor) -> Tensor: ...@@ -222,7 +223,7 @@ def sqrt(x: Tensor) -> Tensor:
def square(x: Tensor) -> Tensor: def square(x: Tensor) -> Tensor:
""" """
Return a new tensor with the square of the elements of input tensor. Returns a new tensor with the square of the elements of input tensor.
:param inp: The input tensor :param x: input tensor.
:return: The computed tensor :return: computed tensor.
...@@ -251,27 +252,27 @@ def square(x: Tensor) -> Tensor: ...@@ -251,27 +252,27 @@ def square(x: Tensor) -> Tensor:
def round(x): def round(x):
"""Element-wise rounding to int.""" """Element-wise `rounding to int`."""
return _elwise(x, mode="round") return _elwise(x, mode="round")
def ceil(x): def ceil(x):
"""Element-wise ceiling.""" """Element-wise `ceiling`."""
return _elwise(x, mode="ceil") return _elwise(x, mode="ceil")
def floor(x): def floor(x):
"""Element-wise floor.""" """Element-wise `floor`."""
return _elwise(x, mode="floor") return _elwise(x, mode="floor")
def maximum(x, y): def maximum(x, y):
"""Element-wise maximum of array elements.""" """Element-wise `maximum of array elements`."""
return _elwise(x, y, mode="max") return _elwise(x, y, mode="max")
def minimum(x, y): def minimum(x, y):
"""Element-wise minimum of array elements.""" """Element-wise `minimum of array elements`."""
return _elwise(x, y, mode="min") return _elwise(x, y, mode="min")
...@@ -279,7 +280,7 @@ def minimum(x, y): ...@@ -279,7 +280,7 @@ def minimum(x, y):
def cos(x): def cos(x):
"""Element-wise cosine. """Element-wise `cosine`.
:param x: input tensor. :param x: input tensor.
:return: computed tensor. :return: computed tensor.
...@@ -308,68 +309,68 @@ def cos(x): ...@@ -308,68 +309,68 @@ def cos(x):
def sin(x): def sin(x):
"""Element-wise sine.""" """Element-wise `sine`."""
return _elwise(x, mode="sin") return _elwise(x, mode="sin")
def tan(x): def tan(x):
"""Element-wise tangent.""" """Element-wise `tangent`."""
return sin(x) / cos(x) return sin(x) / cos(x)
def acos(x): def acos(x):
"""Element-wise inverse cosine.""" """Element-wise `inverse cosine`."""
return _elwise(x, mode="acos") return _elwise(x, mode="acos")
def asin(x): def asin(x):
"""Element-wise inverse sine.""" """Element-wise `inverse sine`."""
return _elwise(x, mode="asin") return _elwise(x, mode="asin")
def atan(x): def atan(x):
"""Element-wise inverse tangent.""" """Element-wise `inverse tangent`."""
return _elwise(x, 1, mode="atan2") return _elwise(x, 1, mode="atan2")
def atan2(y, x): def atan2(y, x):
"""Element-wise 2-argument arctangent.""" """Element-wise `2-argument arctangent`."""
return _elwise(y, x, mode="atan2") return _elwise(y, x, mode="atan2")
def cosh(x): def cosh(x):
r"""Element-wise hyperbolic cosine.""" r"""Element-wise `hyperbolic cosine`."""
return 0.5 * (exp(x) + exp(-x)) return 0.5 * (exp(x) + exp(-x))
def sinh(x): def sinh(x):
r"""Element-wise hyperbolic sine.""" r"""Element-wise `hyperbolic sine`."""
u = expm1(x) u = expm1(x)
return 0.5 * u / (u + 1) * (u + 2) return 0.5 * u / (u + 1) * (u + 2)
def tanh(x): def tanh(x):
r"""Element-wise hyperbolic tangent.""" r"""Element-wise `hyperbolic tangent`."""
return _elwise(x, mode="tanh") return _elwise(x, mode="tanh")
def asinh(x): def asinh(x):
r"""Element-wise inverse hyperbolic sine.""" r"""Element-wise `inverse hyperbolic sine`."""
return log(x + (x ** 2 + 1) ** 0.5) return log(x + (x ** 2 + 1) ** 0.5)
def acosh(x): def acosh(x):
r"""Element-wise inverse hyperbolic cosine.""" r"""Element-wise `inverse hyperbolic cosine`."""
return log(x + (x ** 2 - 1) ** 0.5) return log(x + (x ** 2 - 1) ** 0.5)
def atanh(x): def atanh(x):
r"""Element-wise inverse hyperbolic tangent.""" r"""Element-wise `inverse hyperbolic tangent`."""
return log1p(2 * x / (1 - x)) / 2 return log1p(2 * x / (1 - x)) / 2
def fast_tanh(x): def fast_tanh(x):
r"""Element-wise fast tanh; this is an approximation: r"""Element-wise `fast tanh`; this is an approximation:
.. math:: .. math::
\text{fast_tanh}(x) = x * (27. + x * x) / (27. + 9. * x * x) \text{fast_tanh}(x) = x * (27. + x * x) / (27. + 9. * x * x)
...@@ -381,7 +382,7 @@ def fast_tanh(x): ...@@ -381,7 +382,7 @@ def fast_tanh(x):
def left_shift(x, y): def left_shift(x, y):
"""Element-wise bitwise binary: x << y. """Element-wise `bitwise binary: x << y`.
:param x: input tensor, should be int. :param x: input tensor, should be int.
:param y: how many bits to be left-shifted. :param y: how many bits to be left-shifted.
...@@ -411,7 +412,7 @@ def left_shift(x, y): ...@@ -411,7 +412,7 @@ def left_shift(x, y):
def right_shift(x, y): def right_shift(x, y):
"""Element-wise bitwise binary: x >> y.""" """Element-wise `bitwise binary: x >> y`."""
return _elwise(x, y, mode="shr") return _elwise(x, y, mode="shr")
...@@ -419,22 +420,22 @@ def right_shift(x, y): ...@@ -419,22 +420,22 @@ def right_shift(x, y):
def logical_and(x, y): def logical_and(x, y):
"""Element-wise logical and: x && y.""" """Element-wise `logical and: x && y`."""
return _elwise(x, y, mode="AND") return _elwise(x, y, mode="AND")
def logical_not(x): def logical_not(x):
"""Element-wise logical not: ~x.""" """Element-wise `logical not: ~x`."""
return _elwise(x, mode="NOT") return _elwise(x, mode="NOT")
def logical_or(x, y): def logical_or(x, y):
"""Element-wise logical or: x || y.""" """Element-wise `logical or: x || y`."""
return _elwise(x, y, mode="OR") return _elwise(x, y, mode="OR")
def logical_xor(x, y): def logical_xor(x, y):
"""Element-wise logical xor: x ^ y.""" """Element-wise `logical xor: x ^ y`."""
return _elwise(x, y, mode="XOR") return _elwise(x, y, mode="XOR")
...@@ -442,7 +443,7 @@ def logical_xor(x, y): ...@@ -442,7 +443,7 @@ def logical_xor(x, y):
def eq(x, y): def eq(x, y):
"""Element-wise (x == y). """Element-wise `(x == y)`.
:param x: input tensor 1. :param x: input tensor 1.
:param y: input tensor 2. :param y: input tensor 2.
...@@ -473,27 +474,27 @@ def eq(x, y): ...@@ -473,27 +474,27 @@ def eq(x, y):
def ne(x, y): def ne(x, y):
"""Element-wise (x != y).""" """Element-wise `(x != y)`."""
return x != y return x != y
def lt(x, y): def lt(x, y):
"""Element-wise (x < y).""" """Element-wise `(x < y)`."""
return _elwise(x, y, mode="lt") return _elwise(x, y, mode="lt")
def le(x, y): def le(x, y):
"""Element-wise (x <= y).""" """Element-wise `(x <= y)`."""
return _elwise(x, y, mode="leq") return _elwise(x, y, mode="leq")
def gt(x, y): def gt(x, y):
"""Element-wise (x > y).""" """Element-wise `(x > y)`."""
return _elwise(y, x, mode="lt") return _elwise(y, x, mode="lt")
def ge(x, y): def ge(x, y):
"""Element-wise (x >= y).""" """Element-wise `(x >= y)`."""
return _elwise(y, x, mode="leq") return _elwise(y, x, mode="leq")
...@@ -501,7 +502,7 @@ def ge(x, y): ...@@ -501,7 +502,7 @@ def ge(x, y):
def hswish(x): def hswish(x):
"""Element-wise x * relu6(x + 3) / 6. """Element-wise `x * relu6(x + 3) / 6`.
:param x: input tensor. :param x: input tensor.
:return: computed tensor. :return: computed tensor.
...@@ -527,7 +528,7 @@ def hswish(x): ...@@ -527,7 +528,7 @@ def hswish(x):
def hsigmoid(x): def hsigmoid(x):
"""Element-wise relu6(x + 3) / 6.""" """Element-wise `relu6(x + 3) / 6`."""
return relu6(x + 3) / 6 return relu6(x + 3) / 6
...@@ -537,12 +538,12 @@ def relu(x): ...@@ -537,12 +538,12 @@ def relu(x):
def relu6(x): def relu6(x):
"""Element-wise min(max(x, 0), 6).""" """Element-wise `min(max(x, 0), 6)`."""
return minimum(maximum(x, 0), 6) return minimum(maximum(x, 0), 6)
def sigmoid(x): def sigmoid(x):
"""Element-wise 1 / ( 1 + exp( -x ) ).""" """Element-wise `1 / ( 1 + exp( -x ) )`."""
return _elwise(x, mode="sigmoid") return _elwise(x, mode="sigmoid")
......
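To make the naming convention of these element-wise helpers concrete, a minimal sketch (values chosen so the results are easy to check mentally):

.. code-block::

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    x = tensor(np.array([-1.0, 0.0, 2.0], dtype=np.float32))
    y = tensor(np.array([3.0, 1.0, 2.0], dtype=np.float32))

    print(F.add(x, y).numpy())      # [2. 1. 4.]
    print(F.maximum(x, y).numpy())  # [3. 1. 2.]
    print(F.eq(x, y).numpy())       # elementwise x == y
    print(F.relu6(x + 5).numpy())   # min(max(x + 5, 0), 6) -> [4. 5. 6.]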
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=too-many-lines
from typing import List
from ..tensor import Tensor
def cambricon_subgraph(
inputs: List[Tensor], data: bytes, symbol: str, tensor_dim_mutable: bool,
) -> List[Tensor]:
"""Loads a serialized Cambricon subgraph (i.e. cnrtModel_t) and
executes the operations defined in the subgraph.
:param inputs: list of input tensors of the subgraph.
:param data: the serialized subgraph.
:param symbol: the name of the function in the subgraph.
The function corresponds to a cnmlFusionOp
which is added to the cnmlModel_t/cnrtModel_t.
:param tensor_dim_mutable: whether the input tensors' shapes are mutable
in cnrtModel_t.
"""
raise NotImplementedError
def extern_opr_subgraph(
inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes,
) -> List[Tensor]:
"""Loads a serialized extern opr subgraph and fake execute the operator.
:param inputs: tensor or list of input tensors.
:param output_shapes: the output shapes.
:param dump_name: the serialized subgraph name.
:param dump_data: the serialized subgraph.
:return: list of tensors.
"""
raise NotImplementedError
...@@ -132,7 +132,7 @@ def cross_entropy_with_softmax( ...@@ -132,7 +132,7 @@ def cross_entropy_with_softmax(
.. math:: y^{LS}_{k}=y_{k}\left(1-\alpha\right)+\alpha/K .. math:: y^{LS}_{k}=y_{k}\left(1-\alpha\right)+\alpha/K
where :math:`y^{LS}` and :math:`y` are new label distribution and origin label distribution respectively. where :math:`y^{LS}` and :math:`y` are new label distribution and origin label distribution respectively.
k is the index of label distribution. :math:`\alpha` is label_smooth and :math:`K` is the number of classes. k is the index of label distribution. :math:`\alpha` is ``label_smooth`` and :math:`K` is the number of classes.
:param pred: input tensor representing the predicted probability. :param pred: input tensor representing the predicted probability.
:param label: input tensor representing the classification label. :param label: input tensor representing the classification label.
...@@ -188,7 +188,7 @@ def cross_entropy_with_softmax( ...@@ -188,7 +188,7 @@ def cross_entropy_with_softmax(
def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor: def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
r"""Function that measures the Binary Cross Entropy between the target and the prediction. r"""Function that measures the Binary Cross Entropy between the target and the prediction.
:param pred: `(N, *)` where `*` means any number of additional dimensions. :param pred: `(N, *)`, where `*` means any number of additional dimensions.
:param label: `(N, *)`, same shape as the input. :param label: `(N, *)`, same shape as the input.
:return: loss value. :return: loss value.
...@@ -216,7 +216,7 @@ def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor: ...@@ -216,7 +216,7 @@ def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor: def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
r"""Caculate the hinge loss which is often used in SVMs. r"""Caculates the hinge loss which is often used in SVM.
The hinge loss can be described as: The hinge loss can be described as:
......
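A hedged sketch of the binary cross entropy helper above; it assumes ``pred`` already holds probabilities (e.g. after a sigmoid) and ``label`` holds 0/1 targets of the same shape:

.. code-block::

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    logits = tensor(np.array([[0.3, -1.2], [2.0, 0.1]], dtype=np.float32))
    pred = F.sigmoid(logits)        # probabilities in (0, 1)
    label = tensor(np.array([[1.0, 0.0], [1.0, 0.0]], dtype=np.float32))

    loss = F.binary_cross_entropy(pred, label)
    print(loss.numpy())             # scalar loss value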
...@@ -46,7 +46,7 @@ def isnan(inp: Tensor) -> Tensor: ...@@ -46,7 +46,7 @@ def isnan(inp: Tensor) -> Tensor:
r"""Returns a new tensor representing if each element is ``NaN`` or not. r"""Returns a new tensor representing if each element is ``NaN`` or not.
:param inp: input tensor. :param inp: input tensor.
:return: a new tensor representing if each element in inp is NaN or not. :return: result tensor.
Examples: Examples:
...@@ -72,7 +72,7 @@ def isinf(inp: Tensor) -> Tensor: ...@@ -72,7 +72,7 @@ def isinf(inp: Tensor) -> Tensor:
r"""Returns a new tensor representing if each element is ``Inf`` or not. r"""Returns a new tensor representing if each element is ``Inf`` or not.
:param inp: input tensor. :param inp: input tensor.
:return: a new tensor representing if each element in inp is Inf or not. :return: result tensor.
Examples: Examples:
...@@ -129,7 +129,7 @@ def sum( ...@@ -129,7 +129,7 @@ def sum(
reduce over all of them. reduce over all of them.
:param inp: input tensor. :param inp: input tensor.
:param axis: dimension to reduce. If None, all the dimensions will be reduced. :param axis: dimension to reduce. If None, all dimensions will be reduced.
Default: None Default: None
:param keepdims: whether the output tensor has axis retained or not. :param keepdims: whether the output tensor has axis retained or not.
Default: False Default: False
...@@ -164,7 +164,7 @@ def prod( ...@@ -164,7 +164,7 @@ def prod(
reduce over all of them. reduce over all of them.
:param inp: input tensor. :param inp: input tensor.
:param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
:param keepdims: whether the output tensor has axis retained or not. Default: False :param keepdims: whether the output tensor has axis retained or not. Default: False
:return: output tensor. :return: output tensor.
...@@ -200,7 +200,7 @@ def mean( ...@@ -200,7 +200,7 @@ def mean(
reduce over all of them. reduce over all of them.
:param inp: input tensor. :param inp: input tensor.
:param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
:param keepdims: whether the output tensor has axis retained or not. Default: False :param keepdims: whether the output tensor has axis retained or not. Default: False
:return: output tensor. :return: output tensor.
...@@ -236,7 +236,7 @@ def var( ...@@ -236,7 +236,7 @@ def var(
reduce over all of them. reduce over all of them.
:param inp: input tensor. :param inp: input tensor.
:param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
:param keepdims: whether the output tensor has axis retained or not. Default: False :param keepdims: whether the output tensor has axis retained or not. Default: False
:return: output tensor. :return: output tensor.
...@@ -276,7 +276,7 @@ def std( ...@@ -276,7 +276,7 @@ def std(
reduce over all of them. reduce over all of them.
:param inp: input tensor. :param inp: input tensor.
:param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
:param keepdims: whether the output tensor has axis retained or not. Default: False :param keepdims: whether the output tensor has axis retained or not. Default: False
:return: output tensor. :return: output tensor.
...@@ -311,7 +311,7 @@ def min( ...@@ -311,7 +311,7 @@ def min(
reduce over all of them. reduce over all of them.
:param inp: input tensor. :param inp: input tensor.
:param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
:param keepdims: whether the output tensor has axis retained or not. Default: False :param keepdims: whether the output tensor has axis retained or not. Default: False
:return: output tensor. :return: output tensor.
...@@ -347,7 +347,7 @@ def max( ...@@ -347,7 +347,7 @@ def max(
reduce over all of them. reduce over all of them.
:param inp: input tensor. :param inp: input tensor.
:param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
:param keepdims: whether the output tensor has axis retained or not. Default: False :param keepdims: whether the output tensor has axis retained or not. Default: False
:return: output tensor. :return: output tensor.
...@@ -427,7 +427,7 @@ def argmin( ...@@ -427,7 +427,7 @@ def argmin(
reduce over all of them. reduce over all of them.
:param inp: input tensor. :param inp: input tensor.
:param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
:param keepdims: whether the output tensor has axis retained or not. Default: False :param keepdims: whether the output tensor has axis retained or not. Default: False
:return: output tensor. :return: output tensor.
...@@ -485,7 +485,7 @@ def argmax( ...@@ -485,7 +485,7 @@ def argmax(
reduce over all of them. reduce over all of them.
:param inp: input tensor. :param inp: input tensor.
:param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
:param keepdims: whether the output tensor has axis retained or not. Default: False :param keepdims: whether the output tensor has axis retained or not. Default: False
:return: output tensor. :return: output tensor.
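The reduction helpers above all share the ``axis``/``keepdims`` convention; a short sketch in which the output shapes are the main point:

.. code-block::

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    x = tensor(np.arange(6, dtype=np.float32).reshape(2, 3))  # [[0 1 2] [3 4 5]]

    print(F.sum(x).numpy())                         # all axes reduced -> 15
    print(F.mean(x, axis=1).numpy())                # shape (2,)   -> [1. 4.]
    print(F.max(x, axis=0, keepdims=True).numpy())  # shape (1, 3) -> [[3. 4. 5.]]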
...@@ -543,15 +543,15 @@ def normalize( ...@@ -543,15 +543,15 @@ def normalize(
given axis. If axis is a list of dimensions, given axis. If axis is a list of dimensions,
reduce over all of them. reduce over all of them.
For a tensor inp of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each For a tensor of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
:math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as: :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as:
.. math:: .. math::
v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}. v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.
:param inp: input tensor. :param inp: input tensor.
:param p: power of value applied to inp. Default: 2 :param p: power of value applied to input tensor. Default: 2
:param axis: dimension to reduce. If None, all the dimensions will be reduced :param axis: dimension to reduce. If None, all dimensions will be reduced
to calculate the norm. Default: None to calculate the norm. Default: None
:param eps: a small value to avoid division by zero. Default: 1e-12 :param eps: a small value to avoid division by zero. Default: 1e-12
:return: normalized output tensor. :return: normalized output tensor.
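A small sketch of the formula above with the default ``p=2``: the row below has L2 norm 5, so each element is divided by 5:

.. code-block::

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    v = tensor(np.array([[3.0, 4.0]], dtype=np.float32))
    out = F.normalize(v, axis=1)   # v / max(||v||_2, eps) per row
    print(out.numpy())             # [[0.6 0.8]]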
...@@ -563,11 +563,11 @@ def normalize( ...@@ -563,11 +563,11 @@ def normalize(
def argsort(inp: Tensor, descending: bool = False) -> Tensor: def argsort(inp: Tensor, descending: bool = False) -> Tensor:
r"""Sorts the target 2d matrix by row, return both the sorted tensor and indices. r"""Returns the indices that would sort the input tensor.
:param inp: input tensor, if 2d, each row will be sorted. :param inp: input tensor. If it's 2d, the result would be an array of indices that show how to sort each row of the input tensor.
:param descending: Sort in descending order, where the largest comes first. Default: False :param descending: sort in descending order, where the largest comes first. Default: False
:return: Tuple of two tensors `(sorted_tensor, indices_of_int32)`. :return: indices of int32 that indicate how to sort the input.
Examples: Examples:
...@@ -604,6 +604,31 @@ def argsort(inp: Tensor, descending: bool = False) -> Tensor: ...@@ -604,6 +604,31 @@ def argsort(inp: Tensor, descending: bool = False) -> Tensor:
def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]: def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
r"""Returns sorted tensor and the indices would sort the input tensor.
:param inp: input tensor. If it's 2d, the result would be sorted by row.
:param descending: sort in descending order, where the largest comes first. Default: False
:return: tuple of two tensors `(sorted_tensor, indices_of_int32)`.
Examples:
.. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
x = tensor(np.array([1,2], dtype=np.float32))
out, indices = F.sort(x)
print(out.numpy())
Outputs:
.. testoutput::
[1. 2.]
"""
assert len(inp.shape) <= 2, "Input should be 1d or 2d" assert len(inp.shape) <= 2, "Input should be 1d or 2d"
if descending: if descending:
order = P.Argsort.Order.DESCENDING order = P.Argsort.Order.DESCENDING
...@@ -626,13 +651,13 @@ def topk( ...@@ -626,13 +651,13 @@ def topk(
kth_only: bool = False, kth_only: bool = False,
no_sort: bool = False, no_sort: bool = False,
) -> Tuple[Tensor, Tensor]: ) -> Tuple[Tensor, Tensor]:
r"""Selects the ``Top-K(by default)`` smallest elements of 2d matrix by row. r"""Selects the ``Top-K``(by default) smallest elements of 2d matrix by row.
:param inp: input tensor, if 2d, each row will be sorted. :param inp: input tensor. If input tensor is 2d, each row will be sorted.
:param k: number of elements needed. :param k: number of elements needed.
:param descending: if true, return the largest elements instead. Default: False :param descending: if True, return the largest elements instead. Default: False
:param kth_only: if true, only the k-th element will be returned. Default: False :param kth_only: if True, only the k-th element will be returned. Default: False
:param no_sort: if true, the returned elements can be unordered. Default: False :param no_sort: if True, the returned elements can be unordered. Default: False
:return: tuple of two tensors `(topk_tensor, indices_of_int32)`. :return: tuple of two tensors `(topk_tensor, indices_of_int32)`.
Examples: Examples:
......
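A sketch of the default behaviour described above, i.e. taking the k smallest entries of each row:

.. code-block::

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    x = tensor(np.array([[5.0, 1.0, 3.0, 2.0, 4.0]], dtype=np.float32))
    values, indices = F.topk(x, k=3)   # 3 smallest per row by default
    print(values.numpy())              # [[1. 2. 3.]]
    print(indices.numpy())             # [[1 3 2]]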
...@@ -107,19 +107,18 @@ def conv2d( ...@@ -107,19 +107,18 @@ def conv2d(
:param padding: size of the paddings added to the input on both sides of its :param padding: size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0 spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 2D convolution operation. Default: 1 :param dilation: dilation of the 2D convolution operation. Default: 1
:param groups: number of groups to divide input and output channels into, :param groups: number of groups into which the input and output channels are divided, so as to perform a ``grouped convolution``. When ``groups`` is not 1,
so as to perform a ``grouped convolution``. When groups is not 1, ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
in_channels and out_channels must be divisible by groups,
and the shape of weight should be `(groups, out_channel // groups, and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`. in_channels // groups, height, width)`.
:type conv_mode: string or :class:`P.Convolution.Mode`. :type conv_mode: string or :class:`P.Convolution.Mode`
:param conv_mode: supports "CROSS_CORRELATION" or "CONVOLUTION". Default: :param conv_mode: supports "CROSS_CORRELATION" or "CONVOLUTION". Default:
"CROSS_CORRELATION" "CROSS_CORRELATION"
:type compute_mode: string or :type compute_mode: string or
:class:`P.Convolution.ComputeMode`. :class:`P.Convolution.ComputeMode`
:param compute_mode: when set to "DEFAULT", no special requirements will be :param compute_mode: when set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32", placed on the precision of intermediate results. When set to "FLOAT32",
Float32 would be used for accumulator and intermediate result, but only "Float32" would be used for accumulator and intermediate result, but only
effective when input and output are of Float16 dtype. effective when input and output are of Float16 dtype.
:return: output tensor. :return: output tensor.
""" """
...@@ -168,24 +167,23 @@ def conv_transpose2d( ...@@ -168,24 +167,23 @@ def conv_transpose2d(
:param inp: feature map of the convolution operation. :param inp: feature map of the convolution operation.
:param weight: convolution kernel. :param weight: convolution kernel.
:param bias: bias added to the result of convolution (if given) :param bias: bias added to the result of convolution (if given).
:param stride: stride of the 2D convolution operation. Default: 1 :param stride: stride of the 2D convolution operation. Default: 1
:param padding: size of the paddings added to the input on both sides of its :param padding: size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0 spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 2D convolution operation. Default: 1 :param dilation: dilation of the 2D convolution operation. Default: 1
:param groups: number of groups to divide input and output channels into, :param groups: number of groups into which the input and output channels are divided, so as to perform a ``grouped convolution``. When ``groups`` is not 1,
so as to perform a ``grouped convolution``. When groups is not 1, ``in_channels`` and ``out_channels`` must be divisible by groups,
in_channels and out_channels must be divisible by groups,
and the shape of weight should be `(groups, out_channel // groups, and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`. Default: 1 in_channels // groups, height, width)`. Default: 1
:type conv_mode: string or :class:`P.Convolution.Mode`. :type conv_mode: string or :class:`P.Convolution.Mode`
:param conv_mode: supports "CROSS_CORRELATION" or "CONVOLUTION". Default: :param conv_mode: supports "CROSS_CORRELATION" or "CONVOLUTION". Default:
"CROSS_CORRELATION" "CROSS_CORRELATION"
:type compute_mode: string or :type compute_mode: string or
:class:`P.Convolution.ComputeMode`. :class:`P.Convolution.ComputeMode`
:param compute_mode: when set to "DEFAULT", no special requirements will be :param compute_mode: when set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32", placed on the precision of intermediate results. When set to "FLOAT32",
Float32 would be used for accumulator and intermediate result, but only "Float32" would be used for accumulator and intermediate result, but only
effective when input and output are of Float16 dtype. effective when input and output are of Float16 dtype.
:return: output tensor. :return: output tensor.
""" """
...@@ -224,7 +222,7 @@ def local_conv2d( ...@@ -224,7 +222,7 @@ def local_conv2d(
dilation: Union[int, Tuple[int, int]] = 1, dilation: Union[int, Tuple[int, int]] = 1,
conv_mode="CROSS_CORRELATION", conv_mode="CROSS_CORRELATION",
) -> Tensor: ) -> Tensor:
"""Applies spatial 2D convolution over an image with untied kernels. """Applies spatial 2D convolution over an image with unshared kernels.
Refer to :class:`~.LocalConv2d` for more information. Refer to :class:`~.LocalConv2d` for more information.
""" """
...@@ -264,7 +262,7 @@ def max_pool2d( ...@@ -264,7 +262,7 @@ def max_pool2d(
:param kernel_size: size of the window. :param kernel_size: size of the window.
:param stride: stride of the window. If not provided, its value is set to kernel_size. :param stride: stride of the window. If not provided, its value is set to kernel_size.
Default: None Default: None
:param padding: implicit zero padding to be added on both sides. Default: 0 :param padding: implicit zero padding added on both sides. Default: 0
:return: output tensor. :return: output tensor.
""" """
if stride is None: if stride is None:
...@@ -293,15 +291,15 @@ def avg_pool2d( ...@@ -293,15 +291,15 @@ def avg_pool2d(
padding: Union[int, Tuple[int, int]] = 0, padding: Union[int, Tuple[int, int]] = 0,
mode: str = "AVERAGE_COUNT_EXCLUDE_PADDING", mode: str = "AVERAGE_COUNT_EXCLUDE_PADDING",
) -> Tensor: ) -> Tensor:
"""Applies a 2D average pooling over an input tensor. """Applies 2D average pooling over an input tensor.
Refer to :class:`~.AvgPool2d` for more information. Refer to :class:`~.AvgPool2d` for more information.
:param inp: input tensor. :param inp: input tensor.
:param kernel_size: size of the window. :param kernel_size: size of the window.
:param stride: stride of the window. If not provided, its value is set to kernel_size. :param stride: stride of the window. If not provided, its value is set to ``kernel_size``.
Default: None Default: None
:param padding: implicit zero padding to be added on both sides. Default: 0 :param padding: implicit zero padding added on both sides. Default: 0
:param mode: whether to count padding values. Default: "AVERAGE_COUNT_EXCLUDE_PADDING" :param mode: whether to count padding values. Default: "AVERAGE_COUNT_EXCLUDE_PADDING"
:return: output tensor. :return: output tensor.
""" """
...@@ -349,7 +347,7 @@ def softplus(inp: Tensor) -> Tensor: ...@@ -349,7 +347,7 @@ def softplus(inp: Tensor) -> Tensor:
\text{softplus}(x) = \log(1 + \exp(x)) \text{softplus}(x) = \log(1 + \exp(x))
softplus is a smooth approximation to the ReLU function and can be used softplus is a smooth approximation to the ReLU function and can be used
to constrain the output of a machine to always be positive. to constrain the output to be always positive.
For numerical stability the implementation follows this transformation: For numerical stability the implementation follows this transformation:
.. math:: .. math::
...@@ -357,7 +355,7 @@ def softplus(inp: Tensor) -> Tensor: ...@@ -357,7 +355,7 @@ def softplus(inp: Tensor) -> Tensor:
= \log(1 + \exp(-\text{abs}(x))) + \max(x, 0) = \log(1 + \exp(-\text{abs}(x))) + \max(x, 0)
= \log1p(\exp(-\text{abs}(x))) + \text{relu}(x) = \log1p(\exp(-\text{abs}(x))) + \text{relu}(x)
:param inp: The input tensor :param inp: input tensor.
Examples: Examples:
...@@ -396,8 +394,8 @@ def log_softmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: ...@@ -396,8 +394,8 @@ def log_softmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
= x - \log (\sum_{i}(\exp (x_{i}))) = x - \log (\sum_{i}(\exp (x_{i})))
= x - logsumexp(x) = x - logsumexp(x)
:param inp: The input tensor :param inp: input tensor.
:param axis: An axis along which log_softmax will be applied. :param axis: axis along which log_softmax will be applied.
Examples: Examples:
...@@ -431,7 +429,7 @@ def logsigmoid(inp: Tensor) -> Tensor: ...@@ -431,7 +429,7 @@ def logsigmoid(inp: Tensor) -> Tensor:
= - \log(1 + exp(-x)) = - \log(1 + exp(-x))
= - \text{softplus}(-x) = - \text{softplus}(-x)
:param inp: The input tensor :param inp: input tensor.
Examples: Examples:
...@@ -460,8 +458,7 @@ def logsumexp( ...@@ -460,8 +458,7 @@ def logsumexp(
inp: Tensor, axis: Union[int, Sequence[int]], keepdims: bool = False inp: Tensor, axis: Union[int, Sequence[int]], keepdims: bool = False
) -> Tensor: ) -> Tensor:
r""" r"""
Compute the log of the sum of exponentials of inputs along the given :attr:`axis`. Calculates the logarithm of the inputs' exponential sum along the given :attr:`axis`.
The computation is numerically stabilized.
.. math:: .. math::
...@@ -479,8 +476,8 @@ def logsumexp( ...@@ -479,8 +476,8 @@ def logsumexp(
.. math:: .. math::
b = \max(x_j) b = \max(x_j)
:param inp: The input tensor. :param inp: input tensor.
:param inp: The input tensor. :param inp: input tensor.
:param axis: Axis over which the sum is taken. It can be a single axis or a list of axes. :param axis: axis over which the sum is taken. It could be a single axis or a list of axes.
:param keepdims: whether to retain :attr:`axis` or not for the output tensor. :param keepdims: whether to retain :attr:`axis` or not for the output tensor.
Examples: Examples:
...@@ -524,13 +521,13 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor: ...@@ -524,13 +521,13 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
.. math:: .. math::
\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
It is applied to all elements along axis, and will re-scale them so that It is applied to all elements along axis, and rescales elements so that
the elements lie in the range `[0, 1]` and sum to 1. they stay in the range `[0, 1]` and sum to 1.
See :class:`~megengine.module.activation.Softmax` for more details. See :class:`~megengine.module.activation.Softmax` for more details.
:param inp: The input tensor. :param inp: input tensor.
:param axis: An axis along which softmax will be applied. By default, :param axis: an axis along which softmax will be applied. By default,
softmax will apply along the highest ranked axis. softmax will apply along the highest ranked axis.
Examples: Examples:
...@@ -573,7 +570,7 @@ def batch_norm2d( ...@@ -573,7 +570,7 @@ def batch_norm2d(
eps: float = 1e-5, eps: float = 1e-5,
inplace: bool = True inplace: bool = True
): ):
"""Applies batch normalization to the input. r"""Applies batch normalization to the input.
Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information. Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information.
...@@ -585,13 +582,13 @@ def batch_norm2d( ...@@ -585,13 +582,13 @@ def batch_norm2d(
:param bias: bias tensor in the learnable affine parameters. :param bias: bias tensor in the learnable affine parameters.
See :math:`\beta` in :class:`~.BatchNorm2d`. See :math:`\beta` in :class:`~.BatchNorm2d`.
:param training: a boolean value to indicate whether batch norm is performed :param training: a boolean value to indicate whether batch norm is performed
in traning mode. Default: False in training mode. Default: False
:param momentum: value used for the ``running_mean`` and ``running_var`` :param momentum: value used for the ``running_mean`` and ``running_var``
computation. computation.
Default: 0.9 Default: 0.9
:param eps: a value added to the denominator for numerical stability. :param eps: a value added to the denominator for numerical stability.
Default: 1e-5 Default: 1e-5
:param inplace: whether to update running_mean and running_var inplace or return new tensors :param inplace: whether to update ``running_mean`` and ``running_var`` inplace or return new tensors
Default: True Default: True
:return: output tensor. :return: output tensor.
""" """
...@@ -677,7 +674,7 @@ def sync_batch_norm( ...@@ -677,7 +674,7 @@ def sync_batch_norm(
eps_mode="ADDITIVE", eps_mode="ADDITIVE",
group=WORLD, group=WORLD,
) -> Tensor: ) -> Tensor:
"""Applies synchronized batch normalization to the input. r"""Applies synchronized batch normalization to the input.
Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information. Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information.
...@@ -887,19 +884,18 @@ def matmul( ...@@ -887,19 +884,18 @@ def matmul(
With different inputs dim, this function behaves differently: With different inputs dim, this function behaves differently:
- Both 1-D tensor, simply forward to dot. - Both 1-D tensor, simply forward to ``dot``.
- Both 2-D tensor, normal matrix multiplication. - Both 2-D tensor, normal matrix multiplication.
- If one input tensor is 1-D, matrix vector multiplication. - If one input tensor is 1-D, matrix vector multiplication.
- If at least one tensor are 3-dimensional or >3-dimensional, the batched matrix-matrix is returned, and the tensor with smaller dimension will - If at least one tensor is 3-dimensional or >3-dimensional, the other tensor should have dim >= 2; the batched matrix-matrix product is returned, and the tensor with smaller dimension will
be broadcasted. For example: be broadcasted. For example:
- inp1: `(k, m)`, inp2: `(m, p)`, return: `(k, p)`
- inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)` - inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)`
- inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)` - inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)`
- inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)` - inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)`
:param inp1: The first matrix to be multiplied :param inp1: first matrix to be multiplied.
:param inp2: The second matrix to be multiplied :param inp2: second matrix to be multiplied.
:return: The output tensor :return: output tensor.
Examples: Examples:
...@@ -983,12 +979,12 @@ def matmul( ...@@ -983,12 +979,12 @@ def matmul(
def dot(inp1: Tensor, inp2: Tensor) -> Tensor: def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
""" """
Compute dot-product of two vectors ``inp1`` and ``inp2``. Computes dot-product of two vectors ``inp1`` and ``inp2``.
inputs must be 1-dimensional, scalar input can be automatically broadcasted. inputs must be 1-dimensional; a scalar input can be automatically broadcasted.
:param inp1: The first vector :param inp1: first vector.
:param inp2: The second vector :param inp2: second vector.
:return: The output value :return: output value.
Examples: Examples:
...@@ -1018,10 +1014,10 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor: ...@@ -1018,10 +1014,10 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor: def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
""" """
Compute the singular value decompositions of input matrix ``inp``. Computes the singular value decompositions of input matrix.
:param inp: The input matrix, must has shape ``[..., M, N]`` :param inp: input matrix, must have shape `[..., M, N]`.
:return: The output matrices, U, sigma, V :return: output matrices, `(U, sigma, V)`.
Examples: Examples:
...@@ -1054,8 +1050,7 @@ def interpolate( ...@@ -1054,8 +1050,7 @@ def interpolate(
mode: str = "BILINEAR", mode: str = "BILINEAR",
align_corners: bool = None, align_corners: bool = None,
) -> Tensor: ) -> Tensor:
r"""Down/up samples the input tensor to either the given size or the given r"""Down/up samples the input tensor to either the given size or with the given scale_factor. ``size`` can not coexist with ``scale_factor``.
scale_factor.
:param inp: input tensor. :param inp: input tensor.
:param size: size of the output tensor. Default: None :param size: size of the output tensor. Default: None
...@@ -1198,12 +1193,12 @@ def interpolate( ...@@ -1198,12 +1193,12 @@ def interpolate(
def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor: def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
"""Returns a new tensor where each of the elements are randomly set to zero """Returns a new tensor where each of the elements are randomly set to zero
with probability P = ``drop_prob``. Optionally rescale the output tensor. with probability P = ``drop_prob``. Optionally rescale the output tensor if ``training`` is True.
:param inp: input tensor. :param inp: input tensor.
:param drop_prob: probability to drop (set to zero) a single element. :param drop_prob: probability to drop (set to zero) a single element.
:param training: the default behavior of ``dropout`` during training is to rescale the output, :param training: the default behavior of ``dropout`` during training is to rescale the output,
then it can be replaced by an :class:`~.Identity` during inference, default to True. then it can be replaced by an :class:`~.Identity` during inference. Default: True
:return: the output tensor :return: the output tensor
Examples: Examples:
...@@ -1245,10 +1240,10 @@ def embedding( ...@@ -1245,10 +1240,10 @@ def embedding(
"""Applies lookup table for embedding. """Applies lookup table for embedding.
:param inp: tensor with indices. :param inp: tensor with indices.
:param weight: learnable weights which embedding from. :param weight: learnable weights to look up embeddings from.
:param padding_idx: should be set to None, not support now. :param padding_idx: should be set to None, not supported now.
:param max_norm: should be set to None, not support now. :param max_norm: should be set to None, not supported now.
:param norm_type: should be set to None, not support now. :param norm_type: should be set to None, not supported now.
:return: output tensor. :return: output tensor.
Refer to :class:`~.Embedding` for more information. Refer to :class:`~.Embedding` for more information.
...@@ -1324,14 +1319,14 @@ def roi_align( ...@@ -1324,14 +1319,14 @@ def roi_align(
) -> Tensor: ) -> Tensor:
"""Applies roi align on input feature. """Applies roi align on input feature.
:param inp: tensor that represents the input feature, `(N, C, H, W)` images. :param inp: tensor that represents the input feature, shape is `(N, C, H, W)`.
:param rois: `(N, 5)` boxes. First column is the index into N. The other 4 columns are xyxy. :param rois: `(N, 5)` boxes. First column is the box index. The other 4 columns are ``xyxy``.
:param output_shape: `(height, width)` shape of output rois feature. :param output_shape: `(height, width)` shape of output rois feature.
:param mode: "max" or "average", use max/average align just like max/average pooling. Default: "average" :param mode: "max" or "average", use max/average align just like max/average pooling. Default: "average"
:param spatial_scale: scale the input boxes by this number. Default: 1.0 :param spatial_scale: scale the input boxes by this number. Default: 1.0
:param sample_points: number of inputs samples to take for each output sample. :param sample_points: number of input samples to take for each output sample.
0 to take samples densely. Default: 2 0 to take samples densely. Default: 2
:param aligned: wheather align the input feature, with `aligned=True`, :param aligned: whether to align the input feature, with `aligned=True`,
we first appropriately scale the ROI and then shift it by -0.5. Default: True we first appropriately scale the ROI and then shift it by -0.5. Default: True
:return: output tensor. :return: output tensor.
...@@ -1384,7 +1379,7 @@ def roi_align( ...@@ -1384,7 +1379,7 @@ def roi_align(
def indexing_one_hot( def indexing_one_hot(
src: Tensor, index: Tensor, axis: int = 1, keepdims=False src: Tensor, index: Tensor, axis: int = 1, keepdims=False
) -> Tensor: ) -> Tensor:
r"""One-hot indexing for some axis. r"""One-hot indexing for some axes.
:param src: input tensor. :param src: input tensor.
:param index: index tensor. :param index: index tensor.
...@@ -1427,7 +1422,7 @@ def nms(boxes: Tensor, scores: Tensor, iou_thresh: float) -> Tensor: ...@@ -1427,7 +1422,7 @@ def nms(boxes: Tensor, scores: Tensor, iou_thresh: float) -> Tensor:
Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union(IoU). Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union(IoU).
:param boxes: tensor of shape `(N, 4)`; the boxes to perform nms on; each box is expected to be in `(x1, y1, x2, y2)` format. :param boxes: tensor of shape `(N, 4)`; the boxes to perform nms on; each box is expected to be in `(x1, y1, x2, y2)` format.
:param iou_thresh: iou threshold for overlapping. :param iou_thresh: IoU threshold for overlapping.
:param scores: tensor of shape `(N,)`, the score of boxes. :param scores: tensor of shape `(N,)`, the score of boxes.
:return: indices of the elements that have been kept by NMS. :return: indices of the elements that have been kept by NMS.
...@@ -1483,11 +1478,11 @@ def batched_nms( ...@@ -1483,11 +1478,11 @@ def batched_nms(
r""" r"""
Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union (IoU). Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union (IoU).
:param boxes: tensor of shape `(N, 4)`; the boxes to perform nms on; each box is expected to be in `(x1, y1, x2, y2)` format :param boxes: tensor of shape `(N, 4)`; the boxes to perform nms on; each box is expected to be in `(x1, y1, x2, y2)` format.
:param iou_thresh: iou threshold for overlapping :param iou_thresh: ``IoU`` threshold for overlapping.
:param idxs: tensor of shape `(N,)`, the class indexs of boxes in the batch. :param idxs: tensor of shape `(N,)`, the class indices of boxes in the batch.
:param scores: tensor of shape `(N,)`, the score of boxes. :param scores: tensor of shape `(N,)`, the score of boxes.
:return: indices and the number of the elements that have been kept by NMS :return: indices of the elements that have been kept by NMS.
Examples: Examples:
......
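A hedged sketch of the plain ``nms`` call described above: two heavily overlapping boxes and one distant box, so the lower-scored member of the overlapping pair is suppressed:

.. code-block::

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    boxes = tensor(np.array([[0, 0, 10, 10],
                             [1, 1, 11, 11],
                             [50, 50, 60, 60]], dtype=np.float32))
    scores = tensor(np.array([0.9, 0.8, 0.7], dtype=np.float32))

    keep = F.nms(boxes, scores, iou_thresh=0.5)
    print(keep.numpy())   # indices kept by NMS, expected [0 2]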
...@@ -34,26 +34,23 @@ def conv_bias_activation( ...@@ -34,26 +34,23 @@ def conv_bias_activation(
:param weight: convolution kernel. :param weight: convolution kernel.
:param bias: bias added to the result of convolution :param bias: bias added to the result of convolution
:param stride: stride of the 2D convolution operation. Default: 1 :param stride: stride of the 2D convolution operation. Default: 1
:param padding: size of the paddings added to the input on both sides of its :param padding: size of the paddings added to the input on both sides of its spatial dimensions. Only zero-padding is supported. Default: 0
spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 2D convolution operation. Default: 1 :param dilation: dilation of the 2D convolution operation. Default: 1
:param groups: number of groups to divide input and output channels into, :param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1,
so as to perform a "grouped convolution". When groups is not 1, ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
in_channels and out_channels must be divisible by groups,
and the shape of weight should be `(groups, out_channel // groups, and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`. in_channels // groups, height, width)`.
:type conv_mode: string or :class:`P.Convolution.Mode`. :type conv_mode: string or :class:`P.Convolution.Mode`.
:param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
'CROSS_CORRELATION' 'CROSS_CORRELATION'
:param dtype: support for np.dtype, Default: np.int8 :param dtype: support for ``np.dtype``, Default: np.int8
:param scale: scale if use quantization, Default: 0.0 :param scale: scale used if quantization is applied. Default: 0.0
:param zero_point: scale if use quantization quint8, Default: 0.0 :param zero_point: zero point used if quint8 quantization is applied. Default: 0.0
:type compute_mode: string or :type compute_mode: string or
:class:`P.Convolution.ComputeMode`. :class:`P.Convolution.ComputeMode`.
:param compute_mode: when set to 'DEFAULT', no special requirements will be :param compute_mode: when set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to 'FLOAT32', placed on the precision of intermediate results. When set to "FLOAT32",
Float32 would be used for accumulator and intermediate result, but only "Float32" would be used for accumulator and intermediate result, but only effective when input and output are of Float16 dtype.
effective when input and output are of Float16 dtype.
""" """
ph, pw = _pair(padding) ph, pw = _pair(padding)
......
...@@ -52,6 +52,7 @@ __all__ = [ ...@@ -52,6 +52,7 @@ __all__ = [
"reshape", "reshape",
"remove_axis", "remove_axis",
"split", "split",
"squeeze",
"stack", "stack",
"scatter", "scatter",
"transpose", "transpose",
...@@ -64,8 +65,7 @@ __all__ = [ ...@@ -64,8 +65,7 @@ __all__ = [
def eye(shape, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor: def eye(shape, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor:
"""Returns a 2D tensor with ones on the diagonal and zeros elsewhere. """Returns a 2D tensor with ones on the diagonal and zeros elsewhere.
:param shape: expected shape of otuput tensor. :param shape: expected shape of output tensor.
:param m: number of columns. Default: None
:param dtype: data type. Default: None :param dtype: data type. Default: None
:param device: compute node of the matrix. Default: None :param device: compute node of the matrix. Default: None
:return: eye matrix. :return: eye matrix.
...@@ -171,7 +171,7 @@ def zeros_like(inp: Tensor) -> Tensor: ...@@ -171,7 +171,7 @@ def zeros_like(inp: Tensor) -> Tensor:
def ones_like(inp: Tensor) -> Tensor: def ones_like(inp: Tensor) -> Tensor:
"""Returns a identity tensor with the same shape as input tensor. """Returns a ones tensor with the same shape as input tensor.
""" """
return ones(inp.shape, dtype=inp.dtype, device=inp.device) return ones(inp.shape, dtype=inp.dtype, device=inp.device)
...@@ -183,7 +183,7 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor: ...@@ -183,7 +183,7 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
def identity(inp: Tensor) -> Tensor: def identity(inp: Tensor) -> Tensor:
"""Applies an identity transform to the input tensor. """Applies an identity transformation to input tensor.
:param inp: input tensor. :param inp: input tensor.
:return: output tensor. :return: output tensor.
...@@ -239,8 +239,8 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: ...@@ -239,8 +239,8 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
Concat some tensors Concat some tensors
:param inps: input tensors to concat. :param inps: input tensors to concat.
:param axis: dimension over which the tensors are concatenated. Default: 0 :param axis: over which dimension the tensors are concatenated. Default: 0
:param device: comp node output on. Default: None :param device: the device on which the output will be placed. Default: None
:return: output tensor. :return: output tensor.
Examples: Examples:
...@@ -288,7 +288,7 @@ def stack(inps, axis=0, device=None): ...@@ -288,7 +288,7 @@ def stack(inps, axis=0, device=None):
:param inps: input tensors. :param inps: input tensors.
:param axis: which axis will be concatenated. :param axis: which axis will be concatenated.
:param device: The comp node output on. Default: None :param device: the device on which the output will be placed. Default: None
:return: output concatenated tensor. :return: output concatenated tensor.
Examples: Examples:
...@@ -329,7 +329,7 @@ def split(inp, nsplits_or_sections, axis=0): ...@@ -329,7 +329,7 @@ def split(inp, nsplits_or_sections, axis=0):
When nsplits_or_sections is int, the last tensor may be smaller than others. When nsplits_or_sections is int, the last tensor may be smaller than others.
:param inp: input tensor. :param inp: input tensor.
:param nsplits_or_sections: number of sub tensors or section information list. :param nsplits_or_sections: number of sub tensors or a list of section information.
:param axis: which axis will be splited. :param axis: which axis will be split.
:return: output tensor list. :return: output tensor list.
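A brief, hedged example of ``split`` with an integer ``nsplits_or_sections`` (import paths assumed to be ``megengine.functional`` and ``megengine.tensor``, as in the other examples in these docstrings):

.. code-block::

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    x = tensor(np.arange(20, dtype=np.float32).reshape(10, 2))
    a, b = F.split(x, 2, axis=0)   # two sub tensors, each of shape (5, 2)
    print(a.shape, b.shape)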
...@@ -409,7 +409,7 @@ def _get_idx(index, axis): ...@@ -409,7 +409,7 @@ def _get_idx(index, axis):
def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor: def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
r"""Gathers data from inp on axis using index. r"""Gathers data from input tensor on axis using index.
For a 3-D tensor, the output is specified by:: For a 3-D tensor, the output is specified by::
...@@ -417,14 +417,14 @@ def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor: ...@@ -417,14 +417,14 @@ def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
out[i][j][k] = inp[i][index[i][j][k]][k] # if axis == 1 out[i][j][k] = inp[i][index[i][j][k]][k] # if axis == 1
out[i][j][k] = inp[i][j][index[i][j][k]] # if axis == 2 out[i][j][k] = inp[i][j][index[i][j][k]] # if axis == 2
if inp is an n-dimensional tensor with size if input tensor is an n-dimensional tensor with size
:math:`(x_0,x_1,...,x_{i-1},x_i,x_{i+1},...,x_{n-1})` and axis=i, :math:`(x_0,x_1,...,x_{i-1},x_i,x_{i+1},...,x_{n-1})` and axis=i,
then index must be an n-dimensional tensor with size then index must be an n-dimensional tensor with size
:math:`(x_0,x_1,...,x_{i-1},y,x_{i+1},...,x_{n-1})` where :math:`y\ge 1` and :math:`(x_0,x_1,...,x_{i-1},y,x_{i+1},...,x_{n-1})` where :math:`y\ge 1` and
output will have the same size as index. output will have the same size as index.
:param inp: input tensor. :param inp: input tensor.
:param axis: axis along which to index. :param axis: along which axis to index.
:param index: indices of elements to gather. :param index: indices of elements to gather.
:return: output tensor. :return: output tensor.
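To make the indexing rule above concrete, a minimal sketch (import paths assumed; ``index`` is an int32 tensor with the same number of dimensions as ``inp``):

.. code-block::

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    inp = tensor(np.array([[1., 2., 3.], [4., 5., 6.]], dtype=np.float32))
    index = tensor(np.array([[0, 2], [1, 0]], dtype=np.int32))
    out = F.gather(inp, 1, index)   # out[i][j] = inp[i][index[i][j]]
    # out is [[1. 3.]
    #         [5. 4.]]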
...@@ -480,20 +480,20 @@ def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor: ...@@ -480,20 +480,20 @@ def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
r"""Writes all values from the tensor source into inp r"""Writes all values from the tensor source into input tensor
at the indices specified in the index tensor. at the indices specified in the index tensor.
For each value in source, its output index is specified by its index For each value in source, its output index is specified by its index
in source for ``axis != dimension`` and by the corresponding value in in source for ``axis != dimension`` and by the corresponding value in
index for ``axis = dimension``. index for ``axis = dimension``.
For a 3-D tensor, inp is updated as:: For a 3-D tensor, input tensor is updated as::
inp[index[i][j][k]][j][k] = source[i][j][k] # if axis == 0 inp[index[i][j][k]][j][k] = source[i][j][k] # if axis == 0
inp[i][index[i][j][k]][k] = source[i][j][k] # if axis == 1 inp[i][index[i][j][k]][k] = source[i][j][k] # if axis == 1
inp[i][j][index[i][j][k]] = source[i][j][k] # if axis == 2 inp[i][j][index[i][j][k]] = source[i][j][k] # if axis == 2
inp, index and source should have same number of dimensions. ``inp``, ``index`` and ``source`` should have same number of dimensions.
It is also required that ``source.shape(d) <= inp.shape(d)`` and ``index.shape(d) == source.shape(d)`` It is also required that ``source.shape(d) <= inp.shape(d)`` and ``index.shape(d) == source.shape(d)``
for all dimensions ``d``. for all dimensions ``d``.
...@@ -502,10 +502,10 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: ...@@ -502,10 +502,10 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
.. note:: .. note::
Please notice that, due to performance issues, the result is uncertain on the GPU device Please notice that, due to performance issues, the result is uncertain on the GPU device
if scatter difference positions from source to the same destination position if scattering different positions from source to the same destination position
regard to index tensor. with regard to the index tensor.
Show the case using the following examples, the oup[0][2] is maybe In the following examples, oup[0][2] may come
from source[0][2] which value is 0.2256 or source[1][2] which value is 0.5339 from source[0][2] (whose value is 0.2256) or from source[1][2] (whose value is 0.5339)
if set the index[1][2] from 1 to 0. if index[1][2] is changed from 1 to 0.
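A small sketch of ``scatter`` along axis 0, consistent with the shape constraints above (import paths assumed; values are illustrative):

.. code-block::

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    inp = tensor(np.zeros((3, 5), dtype=np.float32))
    source = tensor(np.arange(1, 11, dtype=np.float32).reshape(2, 5))
    index = tensor(np.array([[0, 1, 2, 0, 1],
                             [2, 0, 1, 1, 2]], dtype=np.int32))
    out = F.scatter(inp, 0, index, source)   # out[index[i][j]][j] = source[i][j]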
...@@ -591,7 +591,7 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: ...@@ -591,7 +591,7 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
\textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i \textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i
:param mask: a mask used for choosing x or y. :param mask: a mask used for choosing ``x`` or ``y``.
:param x: first choice. :param x: first choice.
:param y: second choice. :param y: second choice.
:return: output tensor. :return: output tensor.
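A minimal sketch of ``where`` (import paths assumed; ``mask`` is a boolean tensor with the same shape as ``x`` and ``y``):

.. code-block::

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    mask = tensor(np.array([[True, False], [False, True]]))
    x = tensor(np.array([[1., 2.], [3., 4.]], dtype=np.float32))
    y = tensor(np.array([[5., 6.], [7., 8.]], dtype=np.float32))
    out = F.where(mask, x, y)   # [[1. 6.]
                                #  [7. 4.]]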
...@@ -647,7 +647,7 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: ...@@ -647,7 +647,7 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
def cond_take(mask: Tensor, x: Tensor) -> Tensor: def cond_take(mask: Tensor, x: Tensor) -> Tensor:
r""" r"""
Take elements from data if specific condition is satisfied on mask. Takes elements from data if specific condition is satisfied on mask.
This operator has two outputs: the first is the elements taken, This operator has two outputs: the first is the elements taken,
and the second is the indices corresponding to those elements; and the second is the indices corresponding to those elements;
they are both 1-dimensional. High-dimension input would first be flattened. they are both 1-dimensional. High-dimension input would first be flattened.
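A short sketch of the two outputs described above (import paths assumed; the returned indices refer to the flattened input):

.. code-block::

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    mask = tensor(np.array([[True, False], [False, True]]))
    x = tensor(np.array([[1., 2.], [3., 4.]], dtype=np.float32))
    values, indices = F.cond_take(mask, x)
    # values:  [1. 4.]   elements where mask is True
    # indices: [0 3]     positions in the flattened input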
...@@ -705,7 +705,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: ...@@ -705,7 +705,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
* (2, 0, 1) -> AxBxC to CxAxB * (2, 0, 1) -> AxBxC to CxAxB
* (0, ``'x'``, 1) -> AxB to Ax1xB * (0, ``'x'``, 1) -> AxB to Ax1xB
* (1, ``'x'``, 0) -> AxB to Bx1xA * (1, ``'x'``, 0) -> AxB to Bx1xA
* (1,) -> This remove dimensions 0. It must be a broadcastable dimension (1xA to A) * (1,) -> this removes dimensions 0. It must be a broadcastable dimension (1xA to A)
:return: output tensor. :return: output tensor.
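A tiny sketch of a plain permutation pattern (import paths assumed; the ``'x'`` placeholder shown above is not used here):

.. code-block::

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    x = tensor(np.arange(6, dtype=np.float32).reshape(2, 3))
    y = F.transpose(x, (1, 0))   # AxB -> BxA, i.e. shape (3, 2)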
...@@ -743,8 +743,7 @@ def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: ...@@ -743,8 +743,7 @@ def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
remain unchanged remain unchanged
:param inp: input tensor. :param inp: input tensor.
:param target_shape: target shape, the components would be concatenated to form the :param target_shape: target shape, it can contain an element of -1 representing ``unspec_axis``.
target shape, and it can contain an element of -1 representing unspec_axis.
Examples: Examples:
...@@ -862,7 +861,7 @@ def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: ...@@ -862,7 +861,7 @@ def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
return result return result
add_axis = add_axis expand_dims = add_axis
def remove_axis( def remove_axis(
...@@ -897,6 +896,9 @@ def remove_axis( ...@@ -897,6 +896,9 @@ def remove_axis(
return _remove_axis(inp, axis) return _remove_axis(inp, axis)
squeeze = remove_axis
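A quick sketch of the two aliases added above (import paths assumed; shapes are illustrative):

.. code-block::

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    x = tensor(np.zeros((1, 3, 1), dtype=np.float32))
    y = F.expand_dims(x, 0)   # alias of add_axis, shape becomes (1, 1, 3, 1)
    z = F.squeeze(x, 2)       # alias of remove_axis, shape becomes (1, 3)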
def linspace( def linspace(
start: Union[int, float, Tensor], start: Union[int, float, Tensor],
stop: Union[int, float, Tensor], stop: Union[int, float, Tensor],
...@@ -948,7 +950,7 @@ def arange( ...@@ -948,7 +950,7 @@ def arange(
dtype="float32", dtype="float32",
device: Optional[CompNode] = None, device: Optional[CompNode] = None,
) -> Tensor: ) -> Tensor:
r"""Returns a Tensor with values from start to end with adjacent interval step. r"""Returns a tensor with values from start to end with adjacent interval step.
:param start: starting value of the squence, shoule be scalar. :param start: starting value of the sequence, should be a scalar.
:param end: ending value of the squence, shoule be scalar. :param end: ending value of the sequence, should be a scalar.
...@@ -994,15 +996,15 @@ def arange( ...@@ -994,15 +996,15 @@ def arange(
def param_pack_split(inp: Tensor, offsets: List, shapes: List) -> Tensor: def param_pack_split(inp: Tensor, offsets: List, shapes: List) -> Tensor:
r""" r"""
Returns split Tensor to Tensor list as offsets and shapes described, Splits the input tensor into a list of tensors, as described by offsets and shapes,
only used for parampack. only used for ``parampack``.
:param inp: input tensor. :param inp: input tensor.
:param offsets: offsets of outputs, length of 2 * n, :param offsets: offsets of outputs, length of `2 * n`,
while n is tensor nums you want to split, where n is the number of tensors you want to split into,
format `[begin0, end0, begin1, end1]`. format `[begin0, end0, begin1, end1]`.
:param shapes: tensor shapes of outputs. :param shapes: tensor shapes of outputs.
:return: split tensors. :return: split tensors.
Examples: Examples:
...@@ -1035,13 +1037,13 @@ def param_pack_split(inp: Tensor, offsets: List, shapes: List) -> Tensor: ...@@ -1035,13 +1037,13 @@ def param_pack_split(inp: Tensor, offsets: List, shapes: List) -> Tensor:
def param_pack_concat(inps: List, offsets: Tensor, offsets_val: List) -> Tensor: def param_pack_concat(inps: List, offsets: Tensor, offsets_val: List) -> Tensor:
r""" r"""
Returns concat Tensor, only used for parampack. Returns a concatenated tensor, only used for ``parampack``.
:param inps: input tensors. :param inps: input tensors.
:param offsets: device value of offsets. :param offsets: device value of offsets.
:param offsets_val: offsets of inputs, length of 2 * n, :param offsets_val: offsets of inputs, length of `2 * n`,
format [begin0, end0, begin1, end1]. format `[begin0, end0, begin1, end1]`.
:return: concat tensors :return: concatenated tensor.
Examples: Examples:
......
...@@ -22,7 +22,7 @@ def accuracy( ...@@ -22,7 +22,7 @@ def accuracy(
logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1 logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1
) -> Union[Tensor, Iterable[Tensor]]: ) -> Union[Tensor, Iterable[Tensor]]:
r""" r"""
Calculate the classification accuracy given predicted logits and ground-truth labels. Calculates the classification accuracy given predicted logits and ground-truth labels.
:param logits: model predictions of shape `[batch_size, num_classes]`, :param logits: model predictions of shape `[batch_size, num_classes]`,
representing the probability (likelyhood) of each class. representing the probability (likelihood) of each class.
...@@ -63,25 +63,12 @@ def accuracy( ...@@ -63,25 +63,12 @@ def accuracy(
return accs return accs
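To make the semantics concrete, a NumPy-only illustration of top-1 accuracy for the shapes described above (a conceptual sketch, not the library call, which takes megengine tensors):

.. code-block::

    import numpy as np

    logits = np.array([[0.1, 0.7, 0.2],
                       [0.6, 0.3, 0.1]])   # [batch_size, num_classes]
    target = np.array([1, 2])              # ground-truth class per sample
    pred = logits.argmax(axis=1)           # top-1 prediction: [1, 0]
    top1_acc = (pred == target).mean()     # 0.5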
def zero_grad(inp: Tensor) -> Tensor:
r"""
Returns a tensor which is treated as constant during backward gradient calcuation,
i.e. its gradient is zero.
:param inp: Input tensor.
See implementation of :func:`~.softmax` for example.
"""
print("zero_grad is obsoleted, please use detach instead")
raise NotImplementedError
def copy(inp, cn): def copy(inp, cn):
r""" r"""
Copy tensor to another device. Copies tensor to another device.
:param inp: input tensor. :param inp: input tensor.
:param cn: device that you copy to. :param cn: destination device.
Examples: Examples:
......
...@@ -19,12 +19,12 @@ class InvalidGitHost(FetcherError): ...@@ -19,12 +19,12 @@ class InvalidGitHost(FetcherError):
class GitPullError(FetcherError): class GitPullError(FetcherError):
"""A git pull error occurred""" """A git pull error occurred."""
class GitCheckoutError(FetcherError): class GitCheckoutError(FetcherError):
"""A git checkout error occurred""" """A git checkout error occurred."""
class InvalidProtocol(FetcherError): class InvalidProtocol(FetcherError):
"""The protocol provided was somehow invalid""" """The protocol provided was somehow invalid."""
...@@ -106,20 +106,20 @@ class GitSSHFetcher(RepoFetcherBase): ...@@ -106,20 +106,20 @@ class GitSSHFetcher(RepoFetcherBase):
:param git_host: :param git_host:
host address of git repo. host address of git repo.
example: github.com Example: github.com
:param repo_info: :param repo_info:
a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
tag/branch. The default branch is ``master`` if not specified. tag/branch. The default branch is ``master`` if not specified.
example: ``"brain_sdk/MegBrain[:hub]"`` Example: ``"brain_sdk/MegBrain[:hub]"``
:param use_cache: :param use_cache:
whether to use locally fetched code or completely re-fetch whether to use locally fetched code or completely re-fetch.
:param commit: :param commit:
commit id on github or gitlab commit id on github or gitlab.
:param silent: :param silent:
whether to accept the stdout and stderr of the subprocess with PIPE, instead of whether to accept the stdout and stderr of the subprocess with PIPE, instead of
displaying on the screen displaying on the screen.
:return: :return:
directory where the repo code is stored directory where the repo code is stored.
""" """
if not cls._check_git_host(git_host): if not cls._check_git_host(git_host):
raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host)) raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host))
...@@ -215,24 +215,24 @@ class GitHTTPSFetcher(RepoFetcherBase): ...@@ -215,24 +215,24 @@ class GitHTTPSFetcher(RepoFetcherBase):
silent: bool = True, silent: bool = True,
) -> str: ) -> str:
""" """
Fetches git repo by HTTPS protocol Fetches git repo by HTTPS protocol.
:param git_host: :param git_host:
host address of git repo host address of git repo.
example: github.com Example: github.com
:param repo_info: :param repo_info:
a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
tag/branch. The default branch is ``master`` if not specified. tag/branch. The default branch is ``master`` if not specified.
example: ``"brain_sdk/MegBrain[:hub]"`` Example: ``"brain_sdk/MegBrain[:hub]"``
:param use_cache: :param use_cache:
whether to use locally cached code or completely re-fetch whether to use locally cached code or completely re-fetch.
:param commit: :param commit:
commit id on github or gitlab commit id on github or gitlab.
:param silent: :param silent:
whether to accept the stdout and stderr of the subprocess with PIPE, instead of whether to accept the stdout and stderr of the subprocess with PIPE, instead of
displaying on the screen displaying on the screen.
:return: :return:
directory where the repo code is stored directory where the repo code is stored.
""" """
if not cls._check_git_host(git_host): if not cls._check_git_host(git_host):
raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host)) raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host))
......
...@@ -94,24 +94,24 @@ def _init_hub( ...@@ -94,24 +94,24 @@ def _init_hub(
commit: str = None, commit: str = None,
protocol: str = DEFAULT_PROTOCOL, protocol: str = DEFAULT_PROTOCOL,
): ):
"""Imports hubmodule like python import """Imports hubmodule like python import.
:param repo_info: :param repo_info:
a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
tag/branch. The default branch is ``master`` if not specified. tag/branch. The default branch is ``master`` if not specified.
Example: ``"brain_sdk/MegBrain[:hub]"`` Example: ``"brain_sdk/MegBrain[:hub]"``
:param git_host: :param git_host:
host address of git repo host address of git repo.
Example: github.com Example: github.com
:param use_cache: :param use_cache:
whether to use locally cached code or completely re-fetch whether to use locally cached code or completely re-fetch.
:param commit: :param commit:
commit id on github or gitlab commit id on github or gitlab.
:param protocol: :param protocol:
which protocol to use to get the repo, and HTTPS protocol only supports public repo on github. which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
The value should be one of HTTPS, SSH. The value should be one of HTTPS, SSH.
:return: :return:
hubconf.py as a python module hubconf.py as a python module.
""" """
cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub")) cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
os.makedirs(cache_dir, exist_ok=True) os.makedirs(cache_dir, exist_ok=True)
...@@ -137,24 +137,24 @@ def list( ...@@ -137,24 +137,24 @@ def list(
commit: str = None, commit: str = None,
protocol: str = DEFAULT_PROTOCOL, protocol: str = DEFAULT_PROTOCOL,
) -> List[str]: ) -> List[str]:
"""Lists all entrypoints available in repo hubconf """Lists all entrypoints available in repo hubconf.
:param repo_info: :param repo_info:
a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
tag/branch. The default branch is ``master`` if not specified. tag/branch. The default branch is ``master`` if not specified.
Example: ``"brain_sdk/MegBrain[:hub]"`` Example: ``"brain_sdk/MegBrain[:hub]"``
:param git_host: :param git_host:
host address of git repo host address of git repo.
Example: github.com Example: github.com
:param use_cache: :param use_cache:
whether to use locally cached code or completely re-fetch whether to use locally cached code or completely re-fetch.
:param commit: :param commit:
commit id on github or gitlab commit id on github or gitlab.
:param protocol: :param protocol:
which protocol to use to get the repo, and HTTPS protocol only supports public repo on github. which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
The value should be one of HTTPS, SSH. The value should be one of HTTPS, SSH.
:return: :return:
all entrypoint names of the model all entrypoint names of the model.
""" """
hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol) hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
...@@ -182,14 +182,14 @@ def load( ...@@ -182,14 +182,14 @@ def load(
tag/branch. The default branch is ``master`` if not specified. tag/branch. The default branch is ``master`` if not specified.
Example: ``"brain_sdk/MegBrain[:hub]"`` Example: ``"brain_sdk/MegBrain[:hub]"``
:param entry: :param entry:
an entrypoint defined in hubconf an entrypoint defined in hubconf.
:param git_host: :param git_host:
host address of git repo host address of git repo.
Example: github.com Example: github.com
:param use_cache: :param use_cache:
whether to use locally cached code or completely re-fetch whether to use locally cached code or completely re-fetch.
:param commit: :param commit:
commit id on github or gitlab commit id on github or gitlab.
:param protocol: :param protocol:
which protocol to use to get the repo, and HTTPS protocol only supports public repo on github. which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
The value should be one of HTTPS, SSH. The value should be one of HTTPS, SSH.
...@@ -217,9 +217,9 @@ def help( ...@@ -217,9 +217,9 @@ def help(
) -> str: ) -> str:
"""This function returns docstring of entrypoint ``entry`` by following steps: """This function returns docstring of entrypoint ``entry`` by following steps:
1. Pull the repo code specified by git and repo_info 1. Pull the repo code specified by git and repo_info.
2. Load the entry defined in repo's hubconf.py 2. Load the entry defined in repo's hubconf.py.
3. Return docstring of function entry 3. Return docstring of function entry.
:param repo_info: :param repo_info:
a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
...@@ -228,17 +228,17 @@ def help( ...@@ -228,17 +228,17 @@ def help(
:param entry: :param entry:
an entrypoint defined in hubconf.py an entrypoint defined in hubconf.py
:param git_host: :param git_host:
host address of git repo host address of git repo.
Example: github.com Example: github.com
:param use_cache: :param use_cache:
whether to use locally cached code or completely re-fetch whether to use locally cached code or completely re-fetch.
:param commit: :param commit:
commit id on github or gitlab commit id on github or gitlab.
:param protocol: :param protocol:
which protocol to use to get the repo, and HTTPS protocol only supports public repo on github. which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
The value should be one of HTTPS, SSH. The value should be one of HTTPS, SSH.
:return: :return:
docstring of entrypoint ``entry`` docstring of entrypoint ``entry``.
""" """
hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol) hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
...@@ -255,10 +255,10 @@ def load_serialized_obj_from_url(url: str, model_dir=None) -> Any: ...@@ -255,10 +255,10 @@ def load_serialized_obj_from_url(url: str, model_dir=None) -> Any:
If the object is already present in ``model_dir``, it's deserialized and If the object is already present in ``model_dir``, it's deserialized and
returned. If no ``model_dir`` is specified, it will be ``MGE_HOME/serialized``. returned. If no ``model_dir`` is specified, it will be ``MGE_HOME/serialized``.
:param url: url to serialized object :param url: url to serialized object.
:param model_dir: dir to cache target serialized file :param model_dir: dir to cache target serialized file.
:return: loaded object :return: loaded object.
""" """
if model_dir is None: if model_dir is None:
model_dir = os.path.join(_get_megengine_home(), "serialized") model_dir = os.path.join(_get_megengine_home(), "serialized")
......
...@@ -15,10 +15,10 @@ from typing import Iterator ...@@ -15,10 +15,10 @@ from typing import Iterator
def load_module(name: str, path: str) -> types.ModuleType: def load_module(name: str, path: str) -> types.ModuleType:
""" """
Loads module specified by name and path Loads module specified by name and path.
:param name: module name :param name: module name.
:param path: module path :param path: module path.
""" """
spec = importlib.util.spec_from_file_location(name, path) spec = importlib.util.spec_from_file_location(name, path)
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
...@@ -27,18 +27,18 @@ def load_module(name: str, path: str) -> types.ModuleType: ...@@ -27,18 +27,18 @@ def load_module(name: str, path: str) -> types.ModuleType:
def check_module_exists(module: str) -> bool: def check_module_exists(module: str) -> bool:
"""Checks whether python module exists or not """Checks whether python module exists or not.
:param module: name of module :param module: name of module.
""" """
return importlib.util.find_spec(module) is not None return importlib.util.find_spec(module) is not None
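The helper above boils down to a spec lookup; the equivalent stand-alone check looks like this (module names are illustrative):

.. code-block::

    import importlib.util

    print(importlib.util.find_spec("numpy") is not None)               # True if installed
    print(importlib.util.find_spec("surely_missing_pkg") is not None)  # False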
@contextmanager @contextmanager
def cd(target: str) -> Iterator[None]: def cd(target: str) -> Iterator[None]:
"""Changes current directory to target """Changes current directory to target.
:param target: target directory :param target: target directory.
""" """
prev = os.getcwd() prev = os.getcwd()
os.chdir(os.path.expanduser(target)) os.chdir(os.path.expanduser(target))
......
...@@ -20,10 +20,10 @@ class Softmax(Module): ...@@ -20,10 +20,10 @@ class Softmax(Module):
.. math:: .. math::
\text{Softmax}(x_{i}) = \frac{exp(x_i)}{\sum_j exp(x_j)} \text{Softmax}(x_{i}) = \frac{exp(x_i)}{\sum_j exp(x_j)}
It is applied to an n-dimensional input Tensor and rescaling them so that the elements of the It is applied to all elements along axis, and rescales elements so that
n-dimensional output Tensor lie in the range of `[0, 1]` and sum to 1. they stay in the range `[0, 1]` and sum to 1.
:param axis: An axis along which softmax will be applied. By default, :param axis: Along which axis softmax will be applied. By default,
softmax will apply along the highest ranked axis. softmax will apply along the highest ranked axis.
Examples: Examples:
...@@ -141,8 +141,7 @@ class PReLU(Module): ...@@ -141,8 +141,7 @@ class PReLU(Module):
\end{cases} \end{cases}
Here :math:`a` is a learnable parameter. When called without arguments, `PReLU()` uses Here :math:`a` is a learnable parameter. When called without arguments, `PReLU()` uses
a single paramter :math:`a` across all input channel. If called with `PReLU(num_of_channels)`, a single parameter :math:`a` across all input channels. If called with `PReLU(num_of_channels)`, each input channel will have its own :math:`a`.
a seperate :math:`a` is used for each input channle.
:param num_parameters: number of :math:`a` to learn, there is only two :param num_parameters: number of :math:`a` to learn; only two
values are legitimate: 1, or the number of channels at input. Default: 1 values are legitimate: 1, or the number of channels at input. Default: 1
......
...@@ -220,8 +220,8 @@ class BatchNorm2d(_BatchNorm): ...@@ -220,8 +220,8 @@ class BatchNorm2d(_BatchNorm):
of 0.9. of 0.9.
If :attr:`track_running_stats` is set to ``False``, this layer will not If :attr:`track_running_stats` is set to ``False``, this layer will not
keep running estimates, and batch statistics are instead used during keep running estimates, batch statistics are used during
evaluation time. evaluation time instead.
.. note:: .. note::
This :attr:`momentum` argument is different from one used in optimizer This :attr:`momentum` argument is different from one used in optimizer
...@@ -236,15 +236,14 @@ class BatchNorm2d(_BatchNorm): ...@@ -236,15 +236,14 @@ class BatchNorm2d(_BatchNorm):
Spatial Batch Normalization. Spatial Batch Normalization.
:type num_features: int :type num_features: int
:param num_features: usually the :math:`C` from an input of size :param num_features: usually :math:`C` from an input of shape
:math:`(N, C, H, W)` or the highest ranked dimension of an input with :math:`(N, C, H, W)` or the highest ranked dimension of an input with
less than 4D. less than 4D.
:type eps: float :type eps: float
:param eps: a value added to the denominator for numerical stability. :param eps: a value added to the denominator for numerical stability.
Default: 1e-5 Default: 1e-5
:type momentum: float :type momentum: float
:param momentum: the value used for the `running_mean` and `running_var` :param momentum: the value used for the ``running_mean`` and ``running_var`` computation.
computation.
Default: 0.9 Default: 0.9
:type affine: bool :type affine: bool
:param affine: a boolean value that when set to True, this module has :param affine: a boolean value that when set to True, this module has
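A brief usage sketch for an NCHW input (the ``num_features`` value and shapes are illustrative, not from the original docstring):

.. code-block::

    import numpy as np
    from megengine import tensor
    import megengine.module as M

    bn = M.BatchNorm2d(4)   # num_features = C of an (N, C, H, W) input
    x = tensor(np.random.normal(size=(2, 4, 8, 8)).astype(np.float32))
    y = bn(x)               # in training mode, running stats update with momentum 0.9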
......
...@@ -99,8 +99,8 @@ class Conv2d(_ConvNd): ...@@ -99,8 +99,8 @@ class Conv2d(_ConvNd):
\sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k) \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)
where :math:`\star` is the valid 2D cross-correlation operator, where :math:`\star` is the valid 2D cross-correlation operator,
:math:`N` is a batch size, :math:`C` denotes a number of channels, :math:`N` is batch size, :math:`C` denotes number of channels,
:math:`H` is a height of input planes in pixels, and :math:`W` is :math:`H` is height of input planes in pixels, and :math:`W` is
width in pixels. width in pixels.
When `groups == in_channels` and `out_channels == K * in_channels`, When `groups == in_channels` and `out_channels == K * in_channels`,
...@@ -120,9 +120,8 @@ class Conv2d(_ConvNd): ...@@ -120,9 +120,8 @@ class Conv2d(_ConvNd):
:param padding: size of the paddings added to the input on both sides of its :param padding: size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0 spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 2D convolution operation. Default: 1 :param dilation: dilation of the 2D convolution operation. Default: 1
:param groups: number of groups to divide input and output channels into, :param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1,
so as to perform a "grouped convolution". When groups is not 1, ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
in_channels and out_channels must be divisible by groups,
and there would be an extra dimension at the beginning of the weight's and there would be an extra dimension at the beginning of the weight's
shape. Specifically, the shape of weight would be `(groups, shape. Specifically, the shape of weight would be `(groups,
out_channel // groups, in_channels // groups, *kernel_size)`. out_channel // groups, in_channels // groups, *kernel_size)`.
...@@ -130,9 +129,9 @@ class Conv2d(_ConvNd): ...@@ -130,9 +129,9 @@ class Conv2d(_ConvNd):
True True
:param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default: :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default:
`CROSS_CORRELATION` `CROSS_CORRELATION`
:param compute_mode: When set to `DEFAULT`, no special requirements will be :param compute_mode: When set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to `FLOAT32`, placed on the precision of intermediate results. When set to "FLOAT32",
float32 would be used for accumulator and intermediate result, but only "Float32" would be used for accumulator and intermediate result, but only
effective when input and output are of float16 dtype. effective when input and output are of float16 dtype.
Examples: Examples:
...@@ -236,7 +235,7 @@ class ConvTranspose2d(_ConvNd): ...@@ -236,7 +235,7 @@ class ConvTranspose2d(_ConvNd):
r"""Applies a 2D transposed convolution over an input tensor. r"""Applies a 2D transposed convolution over an input tensor.
This module is also known as a deconvolution or a fractionally-strided convolution. This module is also known as a deconvolution or a fractionally-strided convolution.
:class:`ConvTranspose2d` can ben seen as the gradient of :class:`Conv2d` operation :class:`ConvTranspose2d` can be seen as the gradient of :class:`Conv2d` operation
with respect to its input. with respect to its input.
Convolution usually reduces the size of input, while transposed convolution works Convolution usually reduces the size of input, while transposed convolution works
...@@ -252,8 +251,7 @@ class ConvTranspose2d(_ConvNd): ...@@ -252,8 +251,7 @@ class ConvTranspose2d(_ConvNd):
:param padding: size of the paddings added to the input on both sides of its :param padding: size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0 spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 2D convolution operation. Default: 1 :param dilation: dilation of the 2D convolution operation. Default: 1
:param groups: number of groups to divide input and output channels into, :param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1,
so as to perform a "grouped convolution". When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``, ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and there would be an extra dimension at the beginning of the weight's and there would be an extra dimension at the beginning of the weight's
shape. Specifically, the shape of weight would be ``(groups, shape. Specifically, the shape of weight would be ``(groups,
...@@ -262,9 +260,9 @@ class ConvTranspose2d(_ConvNd): ...@@ -262,9 +260,9 @@ class ConvTranspose2d(_ConvNd):
True True
:param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default: :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default:
`CROSS_CORRELATION` `CROSS_CORRELATION`
:param compute_mode: When set to `DEFAULT`, no special requirements will be :param compute_mode: When set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to `FLOAT32`, placed on the precision of intermediate results. When set to "FLOAT32",
float32 would be used for accumulator and intermediate result, but only "Float32" would be used for accumulator and intermediate result, but only
effective when input and output are of float16 dtype. effective when input and output are of float16 dtype.
""" """
...@@ -342,7 +340,7 @@ class ConvTranspose2d(_ConvNd): ...@@ -342,7 +340,7 @@ class ConvTranspose2d(_ConvNd):
class LocalConv2d(Conv2d): class LocalConv2d(Conv2d):
r"""Applies a spatial convolution with untied kernels over an input 4D tensor. r"""Applies a spatial convolution with unshared kernels over an input 4D tensor.
It is also known as the locally connected layer. It is also known as the locally connected layer.
:param in_channels: number of input channels. :param in_channels: number of input channels.
...@@ -355,9 +353,9 @@ class LocalConv2d(Conv2d): ...@@ -355,9 +353,9 @@ class LocalConv2d(Conv2d):
:param stride: stride of the 2D convolution operation. Default: 1 :param stride: stride of the 2D convolution operation. Default: 1
:param padding: size of the paddings added to the input on both sides of its :param padding: size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0 spatial dimensions. Only zero-padding is supported. Default: 0
:param groups: number of groups to divide input and output channels into, :param groups: number of groups into which the input and output channels are divided,
so as to perform a "grouped convolution". When groups is not 1, so as to perform a "grouped convolution". When ``groups`` is not 1,
in_channels and out_channels must be divisible by groups. ``in_channels`` and ``out_channels`` must be divisible by ``groups``.
The shape of weight is `(groups, output_height, output_width, The shape of weight is `(groups, output_height, output_width,
in_channels // groups, *kernel_size, out_channels // groups)`. in_channels // groups, *kernel_size, out_channels // groups)`.
""" """
......
...@@ -11,7 +11,7 @@ from .module import Module ...@@ -11,7 +11,7 @@ from .module import Module
class Dropout(Module): class Dropout(Module):
r"""Randomly set input elements to zeros with the probability :math:`drop\_prob` during training. r"""Randomly sets input elements to zeros with the probability :math:`drop\_prob` during training.
Commonly used in large networks to prevent overfitting. Commonly used in large networks to prevent overfitting.
Note that we perform dropout only during training, we also rescale(multiply) the output tensor Note that we perform dropout only during training; we also rescale (multiply) the output tensor
by :math:`\frac{1}{1 - drop\_prob}`. During inference :class:`~.Dropout` is equal to :class:`~.Identity`. by :math:`\frac{1}{1 - drop\_prob}`. During inference :class:`~.Dropout` is equal to :class:`~.Identity`.
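A short sketch of the training/inference behaviour described above (the ``drop_prob`` keyword is assumed from the formula; values are illustrative):

.. code-block::

    import numpy as np
    from megengine import tensor
    import megengine.module as M

    drop = M.Dropout(drop_prob=0.2)
    x = tensor(np.ones((4, 4), dtype=np.float32))

    drop.train()
    y = drop(x)   # roughly 20% of elements zeroed, survivors scaled by 1 / (1 - 0.2)

    drop.eval()
    z = drop(x)   # identity during inference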
......
...@@ -26,9 +26,9 @@ class Embedding(Module): ...@@ -26,9 +26,9 @@ class Embedding(Module):
:param num_embeddings: size of embedding dictionary. :param num_embeddings: size of embedding dictionary.
:param embedding_dim: size of each embedding vector. :param embedding_dim: size of each embedding vector.
:param padding_idx: should be set to None, not support now. :param padding_idx: should be set to None, not supported now.
:param max_norm: should be set to None, not support now. :param max_norm: should be set to None, not supported now.
:param norm_type: should be set to None, not support now. :param norm_type: should be set to None, not supported now.
:param initial_weight: the learnable weights of the module of shape (num_embeddings, embedding_dim). :param initial_weight: the learnable weights of the module of shape (num_embeddings, embedding_dim).
Examples: Examples:
...@@ -121,8 +121,8 @@ class Embedding(Module): ...@@ -121,8 +121,8 @@ class Embedding(Module):
r""" r"""
Creates Embedding instance from given 2-dimensional FloatTensor. Creates Embedding instance from given 2-dimensional FloatTensor.
:param embeddings: Tensor contained weight for the embedding. :param embeddings: tensor containing the weight for the embedding.
:param freeze: If ``True``, the weight does not get updated during the learning process. Default: ``True``. :param freeze: if ``True``, the weight does not get updated during the learning process. Default: True.
:param padding_idx: should be set to None, not support Now. :param padding_idx: should be set to None, not supported now.
:param max_norm: should be set to None, not support Now. :param max_norm: should be set to None, not supported now.
:param norm_type: should be set to None, not support Now. :param norm_type: should be set to None, not supported now.
......
...@@ -18,48 +18,48 @@ from ..tensor import Tensor ...@@ -18,48 +18,48 @@ from ..tensor import Tensor
def fill_(tensor: Tensor, val: Union[float, int]) -> None: def fill_(tensor: Tensor, val: Union[float, int]) -> None:
"""Fill the given ``tensor`` with value ``val``. """Fills the given ``tensor`` with value ``val``.
:param tensor: An n-dimentional tensor to be initialized :param tensor: tensor to be initialized.
:param val: The value to be filled throughout the tensor :param val: value to be filled throughout the tensor.
""" """
tensor._reset(full(shape=tensor.shape, value=val, dtype=tensor.dtype)) tensor._reset(full(shape=tensor.shape, value=val, dtype=tensor.dtype))
def zeros_(tensor: Tensor) -> None: def zeros_(tensor: Tensor) -> None:
"""Fill the given ``tensor`` with scalar value `0`. """Fills the given ``tensor`` with scalar value `0`.
:param tensor: An n-dimentional tensor to be initialized :param tensor: tensor to be initialized.
""" """
fill_(tensor, 0) fill_(tensor, 0)
def ones_(tensor: Tensor) -> None: def ones_(tensor: Tensor) -> None:
"""Fill the given ``tensor`` with the scalar value `1`. """Fills the given ``tensor`` with the scalar value `1`.
:param tensor: An n-dimentional tensor to be initialized :param tensor: tensor to be initialized.
""" """
fill_(tensor, 1) fill_(tensor, 1)
def uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> None: def uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> None:
r"""Fill the given ``tensor`` with random value sampled from uniform distribution r"""Fills the given ``tensor`` with random value sampled from uniform distribution
:math:`\mathcal{U}(\text{a}, \text{b})`. :math:`\mathcal{U}(\text{a}, \text{b})`.
:param tensor: An n-dimentional tensor to be initialized :param tensor: tensor to be initialized.
:param a: Lower bound of the sampling interval :param a: lower bound of the sampling interval.
:param b: Upper bound of the sampling interval :param b: upper bound of the sampling interval.
""" """
tensor._reset(uniform(size=tensor.shape, low=a, high=b).astype(tensor.dtype)) tensor._reset(uniform(size=tensor.shape, low=a, high=b).astype(tensor.dtype))
def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None: def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None:
r"""Fill the given ``tensor`` with random value sampled from normal distribution r"""Fills the given ``tensor`` with random value sampled from normal distribution
:math:`\mathcal{N}(\text{mean}, \text{std}^2)`. :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
:param tensor: An n-dimentional tensor to be initialized :param tensor: tensor to be initialized.
:param mean: The mean of the normal distribution :param mean: mean of the normal distribution.
:param std: The standard deviation of the normal distribution :param std: standard deviation of the normal distribution.
""" """
tensor._reset(normal(size=tensor.shape, mean=mean, std=std).astype(tensor.dtype)) tensor._reset(normal(size=tensor.shape, mean=mean, std=std).astype(tensor.dtype))
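A small sketch tying the initializers above together (the ``Parameter`` import path is assumed; shapes and values are illustrative):

.. code-block::

    import numpy as np
    from megengine import Parameter
    import megengine.module.init as init

    w = Parameter(np.zeros((8, 4), dtype=np.float32))
    init.ones_(w)                        # every entry becomes 1
    init.uniform_(w, a=-0.1, b=0.1)      # samples from U(-0.1, 0.1)
    init.normal_(w, mean=0.0, std=0.02)  # samples from N(0, 0.02^2)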
...@@ -67,7 +67,7 @@ def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None: ...@@ -67,7 +67,7 @@ def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None:
def calculate_gain( def calculate_gain(
nonlinearity: str, param: Optional[Union[int, float]] = None nonlinearity: str, param: Optional[Union[int, float]] = None
) -> float: ) -> float:
r"""Return a recommended gain value (see the table below) for the given nonlinearity r"""Returns a recommended gain value (see the table below) for the given nonlinearity
function. function.
================= ==================================================== ================= ====================================================
...@@ -81,8 +81,8 @@ def calculate_gain( ...@@ -81,8 +81,8 @@ def calculate_gain(
Leaky Relu :math:`\sqrt{\frac{2}{1 + {\text{negative}_\text{slope}}^2}}` Leaky Relu :math:`\sqrt{\frac{2}{1 + {\text{negative}_\text{slope}}^2}}`
================= ==================================================== ================= ====================================================
:param nonlinearity: Name of the non-linear function :param nonlinearity: name of the non-linear function.
:param param: Optional parameter for leaky_relu. Only effective when :param param: optional parameter for leaky_relu. Only effective when
``nonlinearity`` is "leaky_relu". ``nonlinearity`` is "leaky_relu".
""" """
...@@ -119,10 +119,10 @@ def calculate_gain( ...@@ -119,10 +119,10 @@ def calculate_gain(
def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]: def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]:
""" """
Calculate fan_in / fan_out value for given weight tensor. This function assumes Calculates fan_in / fan_out value for given weight tensor. This function assumes
input tensor is stored in NCHW format. input tensor is stored in ``NCHW`` format.
:param tensor: Weight tensor in NCHW format :param tensor: weight tensor in ``NCHW`` format.
""" """
shape = tensor.shape shape = tensor.shape
ndim = len(shape) ndim = len(shape)
...@@ -148,13 +148,13 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]: ...@@ -148,13 +148,13 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]:
def calculate_correct_fan(tensor: Tensor, mode: str) -> float: def calculate_correct_fan(tensor: Tensor, mode: str) -> float:
""" """
Calculate fan_in or fan_out value for given weight tensor, depending on given Calculates fan_in or fan_out value for given weight tensor, depending on given
``mode``. ``mode``.
See :func:`calculate_fan_in_and_fan_out` for details. See :func:`calculate_fan_in_and_fan_out` for details.
:param tensor: Weight tensor in NCHW format :param tensor: weight tensor in ``NCHW`` format.
:param mode: ``'fan_in'`` or ``'fan_out'`` :param mode: "fan_in" or "fan_out".
""" """
mode = mode.lower() mode = mode.lower()
valid_modes = ["fan_in", "fan_out"] valid_modes = ["fan_in", "fan_out"]
...@@ -168,7 +168,7 @@ def calculate_correct_fan(tensor: Tensor, mode: str) -> float: ...@@ -168,7 +168,7 @@ def calculate_correct_fan(tensor: Tensor, mode: str) -> float:
def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None: def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None:
r"""Fill ``tensor`` with random values sampled from :math:`\mathcal{U}(-a, a)` r"""Fills tensor with random values sampled from :math:`\mathcal{U}(-a, a)`
where where
.. math:: .. math::
...@@ -178,8 +178,8 @@ def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None: ...@@ -178,8 +178,8 @@ def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None:
`Understanding the difficulty of training deep feedforward neural networks` - `Understanding the difficulty of training deep feedforward neural networks` -
Glorot, X. & Bengio, Y. (2010). Glorot, X. & Bengio, Y. (2010).
:param tensor: An n-dimentional tensor to be initialized :param tensor: tensor to be initialized.
:param gain: Scaling factor for :math:`a`. :param gain: scaling factor for :math:`a`.
""" """
fan_in, fan_out = calculate_fan_in_and_fan_out(tensor) fan_in, fan_out = calculate_fan_in_and_fan_out(tensor)
std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
...@@ -188,7 +188,7 @@ def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None: ...@@ -188,7 +188,7 @@ def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None:
def xavier_normal_(tensor: Tensor, gain: float = 1.0) -> None: def xavier_normal_(tensor: Tensor, gain: float = 1.0) -> None:
r"""Fill ``tensor`` with random values sampled from r"""Fills tensor with random values sampled from
:math:`\mathcal{N}(0, \text{std}^2)` where :math:`\mathcal{N}(0, \text{std}^2)` where
.. math:: .. math::
...@@ -198,8 +198,8 @@ def xavier_normal_(tensor: Tensor, gain: float = 1.0) -> None: ...@@ -198,8 +198,8 @@ def xavier_normal_(tensor: Tensor, gain: float = 1.0) -> None:
`Understanding the difficulty of training deep feedforward neural networks` - `Understanding the difficulty of training deep feedforward neural networks` -
Glorot, X. & Bengio, Y. (2010). Glorot, X. & Bengio, Y. (2010).
:param tensor: An n-dimentional tensor to be initialized :param tensor: tensor to be initialized.
:param gain: Scaling factor for :math:`std`. :param gain: scaling factor for :math:`std`.
""" """
fan_in, fan_out = calculate_fan_in_and_fan_out(tensor) fan_in, fan_out = calculate_fan_in_and_fan_out(tensor)
std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
...@@ -209,7 +209,7 @@ def xavier_normal_(tensor: Tensor, gain: float = 1.0) -> None: ...@@ -209,7 +209,7 @@ def xavier_normal_(tensor: Tensor, gain: float = 1.0) -> None:
def msra_uniform_( def msra_uniform_(
tensor: Tensor, a: float = 0, mode: str = "fan_in", nonlinearity: str = "leaky_relu" tensor: Tensor, a: float = 0, mode: str = "fan_in", nonlinearity: str = "leaky_relu"
) -> None: ) -> None:
r"""Fill ``tensor`` wilth random values sampled from r"""Fills tensor wilth random values sampled from
:math:`\mathcal{U}(-\text{bound}, \text{bound})` where :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
.. math:: .. math::
...@@ -219,13 +219,13 @@ def msra_uniform_( ...@@ -219,13 +219,13 @@ def msra_uniform_(
`Delving deep into rectifiers: Surpassing human-level performance on ImageNet `Delving deep into rectifiers: Surpassing human-level performance on ImageNet
classification` classification`
:param tensor: An n-dimentional tensor to be initialized :param tensor: tensor to be initialized.
:param a: Optional parameter for calculating gain for leaky_relu. See :param a: optional parameter for calculating gain for leaky_relu. See
:func:`calculate_gain` for details. :func:`calculate_gain` for details.
:param mode: ``'fan_in'`` or ``'fan_out'``, used to calculate :math:`gain`, the :param mode: "fan_in" or "fan_out", used to calculate :math:`gain`, the
scaling factor for :math:`bound`. See :func:`calculate_fan_in_and_fan_out` for scaling factor for :math:`bound`. See :func:`calculate_fan_in_and_fan_out` for
details. details.
:param nonlinearity: Name of the non-linear function used to calculate :math:`gain`. :param nonlinearity: name of the non-linear function used to calculate :math:`gain`.
See :func:`calculate_gain` for details. See :func:`calculate_gain` for details.
""" """
fan = calculate_correct_fan(tensor, mode) fan = calculate_correct_fan(tensor, mode)
...@@ -238,7 +238,7 @@ def msra_uniform_( ...@@ -238,7 +238,7 @@ def msra_uniform_(
def msra_normal_( def msra_normal_(
tensor: Tensor, a: float = 0, mode: str = "fan_in", nonlinearity: str = "leaky_relu" tensor: Tensor, a: float = 0, mode: str = "fan_in", nonlinearity: str = "leaky_relu"
) -> None: ) -> None:
r"""Fill ``tensor`` wilth random values sampled from r"""Fills tensor wilth random values sampled from
:math:`\mathcal{N}(0, \text{std}^2)` where :math:`\mathcal{N}(0, \text{std}^2)` where
.. math:: .. math::
...@@ -248,13 +248,13 @@ def msra_normal_( ...@@ -248,13 +248,13 @@ def msra_normal_(
`Delving deep into rectifiers: Surpassing human-level performance on ImageNet `Delving deep into rectifiers: Surpassing human-level performance on ImageNet
classification` classification`
:param tensor: An n-dimentional tensor to be initialized :param tensor: tensor to be initialized
:param a: Optional parameter for calculating gain for leaky_relu. See :param a: optional parameter for calculating gain for leaky_relu. See
:func:`calculate_gain` for details. :func:`calculate_gain` for details.
:param mode: ``'fan_in'`` or ``'fan_out'``, used to calculate :math:`gain`, the :param mode: "fan_in" or "fan_out", used to calculate :math:`gain`, the
scaling factor for :math:`gain`. See :func:`calculate_fan_in_and_fan_out` for scaling factor for :math:`gain`. See :func:`calculate_fan_in_and_fan_out` for
details. details.
:param nonlinearity: Name of the non-linear function used to calculate :math:`gain`. :param nonlinearity: name of the non-linear function used to calculate :math:`gain`.
See :func:`calculate_gain` for details. See :func:`calculate_gain` for details.
""" """
fan = calculate_correct_fan(tensor, mode) fan = calculate_correct_fan(tensor, mode)
......
...@@ -25,7 +25,7 @@ class Linear(Module): ...@@ -25,7 +25,7 @@ class Linear(Module):
:param in_features: size of each input sample. :param in_features: size of each input sample.
:param out_features: size of each output sample. :param out_features: size of each output sample.
:param bias: If set to ``False``, the layer will not learn an additive bias. :param bias: if it's ``False``, the layer will not learn an additive ``bias``.
Default: ``True`` Default: ``True``
Examples: Examples:
......
...@@ -76,9 +76,7 @@ class Module(metaclass=ABCMeta): ...@@ -76,9 +76,7 @@ class Module(metaclass=ABCMeta):
pass pass
def register_forward_pre_hook(self, hook: Callable) -> HookHandler: def register_forward_pre_hook(self, hook: Callable) -> HookHandler:
"""Register a hook to handle forward inputs. `hook` should be a function """Registers a hook to handle forward inputs. `hook` should be a function.
Note that `inputs` keyword inputs
:param hook: a function that receive `module` and `inputs`, then return :param hook: a function that receives `module` and `inputs`, then returns
a modified `inputs` or `None`. a modified `inputs` or `None`.
...@@ -87,7 +85,7 @@ class Module(metaclass=ABCMeta): ...@@ -87,7 +85,7 @@ class Module(metaclass=ABCMeta):
return HookHandler(self._forward_pre_hooks, hook) return HookHandler(self._forward_pre_hooks, hook)
def register_forward_hook(self, hook: Callable) -> HookHandler: def register_forward_hook(self, hook: Callable) -> HookHandler:
"""Register a hook to handle forward results. `hook` should be a function that """Registers a hook to handle forward results. `hook` should be a function that
receive `module`, `inputs` and `outputs`, then return a modified `outputs` or `None`. receives `module`, `inputs` and `outputs`, then returns a modified `outputs` or `None`.
This method return a handler with :meth:`~.HookHandler.remove` interface to delete the hook. This method returns a handler with a :meth:`~.HookHandler.remove` interface to delete the hook.
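A minimal sketch of registering and removing a forward hook (``M.Linear`` is just a stand-in module; the hook signature follows the description above, and ``inputs`` is assumed to be the tuple of forward inputs):

.. code-block::

    import megengine.module as M

    net = M.Linear(4, 2)

    def print_shapes(module, inputs, outputs):
        # returning None keeps the original outputs unchanged
        print(type(module).__name__, [i.shape for i in inputs], outputs.shape)

    handle = net.register_forward_hook(print_shapes)
    # ... run forward passes ...
    handle.remove()   # delete the hook through the returned handler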
...@@ -126,12 +124,12 @@ class Module(metaclass=ABCMeta): ...@@ -126,12 +124,12 @@ class Module(metaclass=ABCMeta):
returned iterable is guaranteed to be identical, as long as all the involved returned iterable is guaranteed to be identical, as long as all the involved
module objects' ``__dict__`` does not change thoughout those calls. module objects' ``__dict__`` does not change throughout those calls.
:param recursive: Whether to recursively scan all the submodules. :param recursive: whether to recursively scan all the submodules.
:param with_key: Whether to yield keys along with yielded objects. :param with_key: whether to yield keys along with yielded objects.
:param with_parent: Whether to yield ``self`` along with yielded objects. :param with_parent: whether to yield ``self`` along with yielded objects.
:param prefix: The prefix appended to the yielded keys. :param prefix: prefix appended to the yielded keys.
:param predicate: The predicate function applied to scanned objects. :param predicate: the predicate function applied to scanned objects.
:param seen: A dict that records whether a module has been traversed yet. :param seen: a dict that records whether a module has been traversed yet.
""" """
if seen is None: if seen is None:
seen = set([id(self)]) seen = set([id(self)])
...@@ -193,10 +191,10 @@ class Module(metaclass=ABCMeta): ...@@ -193,10 +191,10 @@ class Module(metaclass=ABCMeta):
self, prefix: Optional[str] = None, recursive: bool = True, **kwargs self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
) -> Iterable[Tuple[str, Parameter]]: ) -> Iterable[Tuple[str, Parameter]]:
"""Returns an iterable for key :class:`~.Parameter` pairs of the module, where """Returns an iterable for key :class:`~.Parameter` pairs of the module, where
``key`` is the dotted path from this module to the :class:`~.Parameter` . ``key`` is the dotted path from this module to the :class:`~.Parameter`.
:param prefix: The prefix prepended to the keys. :param prefix: prefix prepended to the keys.
:param recursive: If ``True``, returns all :class:`~.Parameter` within this :param recursive: if ``True``, returns all :class:`~.Parameter` within this
module, else only returns :class:`~.Parameter` that are direct attributes module, else only returns :class:`~.Parameter` that are direct attributes
of this module. of this module.
""" """
...@@ -225,7 +223,7 @@ class Module(metaclass=ABCMeta): ...@@ -225,7 +223,7 @@ class Module(metaclass=ABCMeta):
Buffer is defined to be :class:`~.Tensor` excluding :class:`~.Parameter`. Buffer is defined to be :class:`~.Tensor` excluding :class:`~.Parameter`.
:param recursive: If ``True``, returns all buffers within this :param recursive: if ``True``, returns all buffers within this
module, else only returns buffers that are direct attributes module, else only returns buffers that are direct attributes
of this module. of this module.
""" """
...@@ -241,8 +239,8 @@ class Module(metaclass=ABCMeta): ...@@ -241,8 +239,8 @@ class Module(metaclass=ABCMeta):
Buffer is defined to be :class:`~.Tensor` excluding :class:`~.Parameter`. Buffer is defined to be :class:`~.Tensor` excluding :class:`~.Parameter`.
:param prefix: The prefix prepended to the keys. :param prefix: prefix prepended to the keys.
:param recursive: If ``True``, returns all buffers within this :param recursive: if ``True``, returns all buffers within this
module, else only returns buffers that are direct attributes module, else only returns buffers that are direct attributes
of this module. of this module.
""" """
...@@ -287,7 +285,7 @@ class Module(metaclass=ABCMeta): ...@@ -287,7 +285,7 @@ class Module(metaclass=ABCMeta):
module, including itself, where 'key' is the dotted path from this module to the module, including itself, where 'key' is the dotted path from this module to the
submodules. submodules.
:param prefix: The prefix prepended to the path. :param prefix: prefix prepended to the path.
""" """
if "with_parent" in kwargs and kwargs["with_parent"]: if "with_parent" in kwargs and kwargs["with_parent"]:
yield ("" if prefix is None else prefix), self, None yield ("" if prefix is None else prefix), self, None
...@@ -298,24 +296,24 @@ class Module(metaclass=ABCMeta): ...@@ -298,24 +296,24 @@ class Module(metaclass=ABCMeta):
) )
def apply(self, fn: "Callable[[Module], Any]") -> None: def apply(self, fn: "Callable[[Module], Any]") -> None:
"""Apply function ``fn`` to all the modules within this module, including """Applies function ``fn`` to all the modules within this module, including
itself. itself.
:param fn: The function to be applied on modules. :param fn: the function to be applied on modules.
""" """
for it in self.modules(): for it in self.modules():
fn(it) fn(it)
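A small sketch of using ``apply`` for custom initialization, assuming the ``megengine.module.init`` helpers and a toy model; the callback runs on every submodule and on the module itself:

.. code-block:: python

    import megengine.module as M
    import megengine.module.init as init

    model = M.Sequential(M.Linear(4, 8), M.ReLU(), M.Linear(8, 2))   # toy model

    def zero_bias(m):
        # called for every module yielded by model.modules(), including model itself
        if isinstance(m, M.Linear) and m.bias is not None:
            init.zeros_(m.bias)

    model.apply(zero_bias)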
@deprecated(version="1.0") @deprecated(version="1.0")
def zero_grad(self) -> None: def zero_grad(self) -> None:
"""Set all parameters' grads to zero """Sets all parameters' grads to zero
""" """
for param in self.parameters(): for param in self.parameters():
if param.grad is not None: if param.grad is not None:
param.grad.reset_zero() param.grad.reset_zero()
def train(self, mode: bool = True, recursive: bool = True) -> None: def train(self, mode: bool = True, recursive: bool = True) -> None:
"""Set training mode of all the modules within this module (including itself) to """Sets training mode of all the modules within this module (including itself) to
``mode``. This effectively sets the ``training`` attributes of those modules ``mode``. This effectively sets the ``training`` attributes of those modules
to ``mode``, but only has effect on certain modules (e.g. to ``mode``, but only has effect on certain modules (e.g.
:class:`~.BatchNorm2d`, :class:`~.Dropout`, :class:`~.Observer`) :class:`~.BatchNorm2d`, :class:`~.Dropout`, :class:`~.Observer`)
...@@ -333,14 +331,14 @@ class Module(metaclass=ABCMeta): ...@@ -333,14 +331,14 @@ class Module(metaclass=ABCMeta):
self.apply(fn) self.apply(fn)
def eval(self) -> None: def eval(self) -> None:
"""Set training mode of all the modules within this module (including itself) to """Sets training mode of all the modules within this module (including itself) to
``False``. See :meth:`~.Module.train` for details. ``False``. See :meth:`~.Module.train` for details.
""" """
self.train(False) self.train(False)
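A brief sketch of toggling the training flag on a toy model; only mode-dependent modules such as ``BatchNorm2d`` and ``Dropout`` change behavior:

.. code-block:: python

    import megengine.module as M

    model = M.Sequential(M.Conv2d(3, 8, 3), M.BatchNorm2d(8), M.Dropout(0.5))  # toy model

    model.train()    # training mode: BatchNorm2d updates running stats, Dropout is active
    # ... run training iterations ...
    model.eval()     # shorthand for model.train(False); use before inference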
def disable_quantize(self, value=True): def disable_quantize(self, value=True):
r""" r"""
Set ``module``'s ``quantize_disabled`` attribute and return ``module``. Sets ``module``'s ``quantize_disabled`` attribute and returns ``module``.
Could be used as a decorator. Could be used as a decorator.
""" """
...@@ -353,7 +351,7 @@ class Module(metaclass=ABCMeta): ...@@ -353,7 +351,7 @@ class Module(metaclass=ABCMeta):
def replace_param( def replace_param(
self, params: dict, start_pos: int, seen: Optional[Set[int]] = None self, params: dict, start_pos: int, seen: Optional[Set[int]] = None
): ):
"""Replace module's parameters with `params`, used by :class:`~.ParamPack` to """Replaces module's parameters with `params`, used by :class:`~.ParamPack` to
speedup multimachine training. speed up multi-machine training.
""" """
offset = 0 offset = 0
...@@ -409,7 +407,7 @@ class Module(metaclass=ABCMeta): ...@@ -409,7 +407,7 @@ class Module(metaclass=ABCMeta):
state_dict: Union[dict, Callable[[str, Tensor], Optional[np.ndarray]]], state_dict: Union[dict, Callable[[str, Tensor], Optional[np.ndarray]]],
strict=True, strict=True,
): ):
r"""Load a given dictionary created by :func:`state_dict` into this module. r"""Loads a given dictionary created by :func:`state_dict` into this module.
If ``strict`` is ``True``, the keys of :func:`state_dict` must exactly match the keys If ``strict`` is ``True``, the keys of :func:`state_dict` must exactly match the keys
returned by :func:`state_dict`. returned by :func:`state_dict`.
......
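A minimal round-trip sketch, assuming ``megengine.save``/``megengine.load`` and an illustrative checkpoint path:

.. code-block:: python

    import megengine as mge
    import megengine.module as M

    model = M.Linear(4, 2)                                 # toy model
    mge.save(model.state_dict(), "checkpoint.pkl")         # path is illustrative
    state = mge.load("checkpoint.pkl")
    model.load_state_dict(state, strict=True)              # keys must match exactly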
...@@ -18,7 +18,7 @@ class Linear(Float.Linear, QATModule): ...@@ -18,7 +18,7 @@ class Linear(Float.Linear, QATModule):
:param in_features: size of each input sample. :param in_features: size of each input sample.
:param out_features: size of each output sample. :param out_features: size of each output sample.
:param bias: If set to ``False``, the layer will not learn an additive bias. :param bias: If set to ``False``, the layer will not learn an additive bias.
Default: ``True`` Default: True
""" """
......
...@@ -15,7 +15,7 @@ from .module import QuantizedModule ...@@ -15,7 +15,7 @@ from .module import QuantizedModule
class Concat(QuantizedModule): class Concat(QuantizedModule):
r""" r"""
A :class:`~.QuantizedModule` to do quantized concat, inference only. A :class:`~.QuantizedModule` to do quantized concat, used for inference only.
""" """
def __init__(self, dtype=None): def __init__(self, dtype=None):
...@@ -29,7 +29,7 @@ class Concat(QuantizedModule): ...@@ -29,7 +29,7 @@ class Concat(QuantizedModule):
@classmethod @classmethod
def from_qat_module(cls, qat_module: QAT.Concat): def from_qat_module(cls, qat_module: QAT.Concat):
r""" r"""
return a :class:`~.QuantizedModule` instance converted from a Return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance. :class:`~.QATModule` instance.
""" """
return cls(qat_module.get_activation_dtype()) return cls(qat_module.get_activation_dtype())
...@@ -18,10 +18,10 @@ from .module import QuantizedModule ...@@ -18,10 +18,10 @@ from .module import QuantizedModule
class Conv2d(Float.Conv2d, QuantizedModule): class Conv2d(Float.Conv2d, QuantizedModule):
r"""quantized version of :class:`~.qat.conv.Conv2d`.""" r"""Quantized version of :class:`~.qat.conv.Conv2d`."""
r"""Applies a 2D convolution over an quantized input tensor, inference only. r"""Applies a 2D convolution over a quantized input tensor, used for inference only.
The parameter is same with :class: `~.Conv2d` The parameters are the same as :class:`~.Conv2d`.
""" """
def __init__( def __init__(
...@@ -101,7 +101,7 @@ class Conv2d(Float.Conv2d, QuantizedModule): ...@@ -101,7 +101,7 @@ class Conv2d(Float.Conv2d, QuantizedModule):
class ConvRelu2d(Conv2d): class ConvRelu2d(Conv2d):
r"""quantized version of :class:`~.qat.conv.ConvRelu2d`.""" r"""Quantized version of :class:`~.qat.conv.ConvRelu2d`."""
def forward(self, inp): def forward(self, inp):
return self.calc_conv_quantized(inp, nonlinear_mode="RELU") return self.calc_conv_quantized(inp, nonlinear_mode="RELU")
...@@ -11,15 +11,15 @@ from .conv import Conv2d ...@@ -11,15 +11,15 @@ from .conv import Conv2d
class _ConvBnActivation2d(Conv2d): class _ConvBnActivation2d(Conv2d):
r"""Applies a 2D convolution over an quantized input tensor, inference only. r"""Applies a 2D convolution over a quantized input tensor, used for inference only.
The parameter is same with :class: `~.Conv2d` The parameters are the same as :class:`~.Conv2d`.
""" """
@classmethod @classmethod
def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d): def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d):
r""" r"""
return a :class:`~.QuantizedModule` instance converted from a Return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance. :class:`~.QATModule` instance.
""" """
output_dtype = qat_module.get_activation_dtype() output_dtype = qat_module.get_activation_dtype()
...@@ -43,14 +43,14 @@ class _ConvBnActivation2d(Conv2d): ...@@ -43,14 +43,14 @@ class _ConvBnActivation2d(Conv2d):
class ConvBn2d(_ConvBnActivation2d): class ConvBn2d(_ConvBnActivation2d):
r"""quantized version of :class:`~.qat.conv_bn.ConvBn2d`.""" r"""Quantized version of :class:`~.qat.conv_bn.ConvBn2d`."""
def forward(self, inp): def forward(self, inp):
return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY") return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")
class ConvBnRelu2d(_ConvBnActivation2d): class ConvBnRelu2d(_ConvBnActivation2d):
r"""quantized version of :class:`~.qat.conv_bn.ConvBnRelu2d`.""" r"""Quantized version of :class:`~.qat.conv_bn.ConvBnRelu2d`."""
def forward(self, inp): def forward(self, inp):
return self.calc_conv_quantized(inp, nonlinear_mode="RELU") return self.calc_conv_quantized(inp, nonlinear_mode="RELU")
...@@ -13,7 +13,7 @@ from .module import QuantizedModule ...@@ -13,7 +13,7 @@ from .module import QuantizedModule
class Elemwise(QuantizedModule): class Elemwise(QuantizedModule):
r"""quantized version of :class:`~.qat.elemwise.Elemwise`.""" r"""Quantized version of :class:`~.qat.elemwise.Elemwise`."""
_elemwise_multi_type_mode = P.ElemwiseMultiType.Mode _elemwise_multi_type_mode = P.ElemwiseMultiType.Mode
...@@ -30,7 +30,7 @@ class Elemwise(QuantizedModule): ...@@ -30,7 +30,7 @@ class Elemwise(QuantizedModule):
@classmethod @classmethod
def from_qat_module(cls, qat_module: QAT.Elemwise): def from_qat_module(cls, qat_module: QAT.Elemwise):
r""" r"""
return a :class:`~.QuantizedModule` instance converted from a Return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance. :class:`~.QATModule` instance.
""" """
return cls(qat_module.method.name, qat_module.get_activation_dtype()) return cls(qat_module.method.name, qat_module.get_activation_dtype())
...@@ -15,7 +15,7 @@ from .module import QuantizedModule ...@@ -15,7 +15,7 @@ from .module import QuantizedModule
class Linear(QuantizedModule): class Linear(QuantizedModule):
r"""quantized version of :class:`~.qat.linear.Linear`.""" r"""Quantized version of :class:`~.qat.linear.Linear`."""
def __init__( def __init__(
self, dtype: np.dtype = None, self, dtype: np.dtype = None,
...@@ -40,7 +40,7 @@ class Linear(QuantizedModule): ...@@ -40,7 +40,7 @@ class Linear(QuantizedModule):
@classmethod @classmethod
def from_qat_module(cls, qat_module: QAT.Linear): def from_qat_module(cls, qat_module: QAT.Linear):
r""" r"""
return a :class:`~.QuantizedModule` instance converted from a Return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance. :class:`~.QATModule` instance.
""" """
output_dtype = qat_module.get_activation_dtype() output_dtype = qat_module.get_activation_dtype()
......
...@@ -26,6 +26,6 @@ class QuantizedModule(Module): ...@@ -26,6 +26,6 @@ class QuantizedModule(Module):
@abstractmethod @abstractmethod
def from_qat_module(cls, qat_module: QATModule): def from_qat_module(cls, qat_module: QATModule):
r""" r"""
return a :class:`~.QuantizedModule` instance converted from a Return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance. :class:`~.QATModule` instance.
""" """
...@@ -11,7 +11,7 @@ from .module import QuantizedModule ...@@ -11,7 +11,7 @@ from .module import QuantizedModule
class QuantStub(QuantizedModule): class QuantStub(QuantizedModule):
r""" r"""
quantized version of :class:`~.qat.quant_dequant.QuantStub`, Quantized version of :class:`~.qat.quant_dequant.QuantStub`,
will convert input to quantized dtype. will convert input to quantized dtype.
""" """
...@@ -25,7 +25,7 @@ class QuantStub(QuantizedModule): ...@@ -25,7 +25,7 @@ class QuantStub(QuantizedModule):
@classmethod @classmethod
def from_qat_module(cls, qat_module: QAT.QuantStub): def from_qat_module(cls, qat_module: QAT.QuantStub):
r""" r"""
return a :class:`~.QuantizedModule` instance converted from a Return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance. :class:`~.QATModule` instance.
""" """
return cls(qat_module.get_activation_dtype()) return cls(qat_module.get_activation_dtype())
...@@ -33,7 +33,7 @@ class QuantStub(QuantizedModule): ...@@ -33,7 +33,7 @@ class QuantStub(QuantizedModule):
class DequantStub(QuantizedModule): class DequantStub(QuantizedModule):
r""" r"""
quantized version of :class:`~.qat.quant_dequant.DequantStub`, Quantized version of :class:`~.qat.quant_dequant.DequantStub`,
will restore quantized input to float32 dtype. will restore quantized input to float32 dtype.
""" """
...@@ -43,7 +43,7 @@ class DequantStub(QuantizedModule): ...@@ -43,7 +43,7 @@ class DequantStub(QuantizedModule):
@classmethod @classmethod
def from_qat_module(cls, qat_module: QAT.DequantStub): def from_qat_module(cls, qat_module: QAT.DequantStub):
r""" r"""
return a :class:`~.QuantizedModule` instance converted from a Return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance. :class:`~.QATModule` instance.
""" """
return cls() return cls()
...@@ -22,13 +22,13 @@ class Adadelta(Optimizer): ...@@ -22,13 +22,13 @@ class Adadelta(Optimizer):
:param params: iterable of parameters to optimize or dicts defining :param params: iterable of parameters to optimize or dicts defining
parameter groups. parameter groups.
:param lr: coefficient that scale delta before it is applied :param lr: coefficient that scales delta before it is applied
to the parameters (default: 1.0). to the parameters. Default: 1.0
:param rho: coefficient used for computing a running average :param rho: coefficient used for computing a running average
of squared gradients (default: 0.9). of squared gradients. Default: 0.9
:param eps: term added to the denominator to improve :param eps: term added to the denominator to improve
numerical stability (default: 1e-6). numerical stability. Default: 1e-6
:param weight_decay: weight decay (L2 penalty) (default: 0). :param weight_decay: weight decay (L2 penalty). Default: 0
""" """
def __init__( def __init__(
......
...@@ -23,12 +23,12 @@ class Adagrad(Optimizer): ...@@ -23,12 +23,12 @@ class Adagrad(Optimizer):
:param params: iterable of parameters to optimize or dicts defining :param params: iterable of parameters to optimize or dicts defining
parameter groups. parameter groups.
:param lr: coefficient that scale delta before it is applied :param lr: coefficient that scales delta before it is applied
to the parameters (default: 1e-2). to the parameters. Default: 1e-2
:param lr_decay: learning rate decay (default: 0) :param lr_decay: learning rate decay. Default: 0
:param eps: term added to the denominator to improve :param eps: term added to the denominator to improve
numerical stability (default: 1e-10). numerical stability. Default: 1e-10
:param weight_decay: weight decay (L2 penalty) (default: 0). :param weight_decay: weight decay (L2 penalty). Default: 0
""" """
def __init__( def __init__(
......
...@@ -14,8 +14,8 @@ from .optimizer import Optimizer ...@@ -14,8 +14,8 @@ from .optimizer import Optimizer
class LRScheduler(metaclass=ABCMeta): class LRScheduler(metaclass=ABCMeta):
r"""Base class for all learning rate based schedulers. r"""Base class for all learning rate based schedulers.
:param optimizer: Wrapped optimizer. :param optimizer: wrapped optimizer.
:param current_epoch: The index of current epoch. Default: -1 :param current_epoch: the index of current epoch. Default: -1
""" """
def __init__( # pylint: disable=too-many-branches def __init__( # pylint: disable=too-many-branches
...@@ -53,7 +53,8 @@ class LRScheduler(metaclass=ABCMeta): ...@@ -53,7 +53,8 @@ class LRScheduler(metaclass=ABCMeta):
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
r"""Loads the schedulers state. r"""Loads the schedulers state.
:param state_dict (dict): scheduler state. :type state_dict: dict
:param state_dict: scheduler state.
""" """
raise NotImplementedError raise NotImplementedError
......
...@@ -17,10 +17,12 @@ class MultiStepLR(LRScheduler): ...@@ -17,10 +17,12 @@ class MultiStepLR(LRScheduler):
r"""Decays the learning rate of each parameter group by gamma once the r"""Decays the learning rate of each parameter group by gamma once the
number of epochs reaches one of the milestones. number of epochs reaches one of the milestones.
:param optimizer: Wrapped optimizer. :param optimizer: wrapped optimizer.
:param milestones (list): List of epoch indices. Must be increasing. :type milestones: list
:param gamma (float): Multiplicative factor of learning rate decay. Default: 0.1. :param milestones: list of epoch indices which should be increasing.
:param current_epoch: The index of current epoch. Default: -1. :type gamma: float
:param gamma: multiplicative factor of learning rate decay. Default: 0.1
:param current_epoch: the index of current epoch. Default: -1
""" """
def __init__( def __init__(
...@@ -55,7 +57,8 @@ class MultiStepLR(LRScheduler): ...@@ -55,7 +57,8 @@ class MultiStepLR(LRScheduler):
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
r"""Loads the schedulers state. r"""Loads the schedulers state.
:param state_dict (dict): scheduler state. :type state_dict: dict
:param state_dict: scheduler state.
""" """
tmp_dict = {} tmp_dict = {}
for key in ["milestones", "gamma", "current_epoch"]: for key in ["milestones", "gamma", "current_epoch"]:
......
...@@ -22,10 +22,10 @@ class _FakeQuantize(Module): ...@@ -22,10 +22,10 @@ class _FakeQuantize(Module):
r""" r"""
A Basic Fake Quant module. A Basic Fake Quant module.
:param dtype: A string indicating the target quantization type of input. :param dtype: a string indicating the target quantization type of input.
:param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``, :param narrow_range: whether the absolute value of ``qmin`` is the same as ``qmax``,
instead of 1 greater. Usually True for weight and False for activation. instead of 1 greater. Usually True for weight and False for activation.
:param enable: Whether do ``normal_forward`` or ``fake_quant_forward``. :param enable: whether to do ``normal_forward`` or ``fake_quant_forward``.
""" """
def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True):
......
...@@ -21,8 +21,8 @@ class Observer(Module): ...@@ -21,8 +21,8 @@ class Observer(Module):
r""" r"""
A base class for Observer Module. A base class for Observer Module.
:param dtype: a string indicating to collect scale and zero_point of which dtype :param dtype: a string indicating which dtype to collect scale and zero_point for.
:param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``, :param narrow_range: whether the absolute value of ``qmin`` is the same as ``qmax``,
instead of 1 greater. Usually True for weight and False for activation. instead of 1 greater. Usually True for weight and False for activation.
""" """
......
...@@ -63,7 +63,7 @@ qparam_dict = { ...@@ -63,7 +63,7 @@ qparam_dict = {
def get_qparam_dict(mode: QuantMode): def get_qparam_dict(mode: QuantMode):
"""Return the quantization parameters dictory according to the mode. """Return the quantization parameters dictionary according to the mode.
""" """
return qparam_dict.get(mode, None) return qparam_dict.get(mode, None)
...@@ -91,7 +91,7 @@ def fake_quant_tensor(inp: Tensor, qmin: int, qmax: int, q_dict: Dict) -> Tensor ...@@ -91,7 +91,7 @@ def fake_quant_tensor(inp: Tensor, qmin: int, qmax: int, q_dict: Dict) -> Tensor
def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor: def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
"""Apply fake quantization to bias, the special scale from input tensor """Apply fake quantization to bias, with the special scale from input tensor
and weight tensor, the quantized type set to qint32 also. and weight tensor, the quantized type set to qint32 also.
:param bias: the bias tensor which need to be faked. :param bias: the bias tensor which need to be faked.
......
...@@ -21,12 +21,12 @@ __all__ = ["normal", "uniform"] ...@@ -21,12 +21,12 @@ __all__ = ["normal", "uniform"]
def normal( def normal(
mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None
) -> Tensor: ) -> Tensor:
r"""Random variable with Gaussian distribution $N(\mu, \sigma)$ r"""Random variable with Gaussian distribution :math:`N(\mu, \sigma)`.
:param size: Output tensor size :param size: output tensor size.
:param mean: The mean or expectation of the distribution :param mean: the mean or expectation of the distribution.
:param std: The standard deviation of the distribution (variance = $\sigma ^ 2$) :param std: the standard deviation of the distribution (variance = :math:`\sigma ^ 2`).
:return: The output tensor :return: the output tensor.
Examples: Examples:
...@@ -59,12 +59,12 @@ def normal( ...@@ -59,12 +59,12 @@ def normal(
def uniform( def uniform(
low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None
) -> Tensor: ) -> Tensor:
r"""Random variable with uniform distribution $U(0, 1)$ r"""Random variable with uniform distribution $U(0, 1)$.
:param size: Output tensor size :param size: output tensor size.
:param low: Lower range :param low: lower range.
:param high: Upper range :param high: upper range.
:return: The output tensor :return: the output tensor.
Examples: Examples:
......
...@@ -23,16 +23,16 @@ HTTP_CONNECTION_TIMEOUT = 5 ...@@ -23,16 +23,16 @@ HTTP_CONNECTION_TIMEOUT = 5
class HTTPDownloadError(BaseException): class HTTPDownloadError(BaseException):
"""The class that represents http request error""" """The class that represents http request error."""
def download_from_url(url: str, dst: str, http_read_timeout=120): def download_from_url(url: str, dst: str, http_read_timeout=120):
""" """
Downloads file from given url to ``dst`` Downloads file from given url to ``dst``.
:param url: source URL :param url: source URL.
:param dst: saving path :param dst: saving path.
:param http_read_timeout: how many seconds to wait for data before giving up :param http_read_timeout: how many seconds to wait for data before giving up.
""" """
dst = os.path.expanduser(dst) dst = os.path.expanduser(dst)
dst_dir = os.path.dirname(dst) dst_dir = os.path.dirname(dst)
......
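A call sketch; the import path, URL and destination below are assumptions used only for illustration:

.. code-block:: python

    from megengine.utils.http_download import download_from_url   # import path is an assumption

    download_from_url(
        "https://example.com/models/resnet18.pkl",                 # placeholder URL
        "~/.cache/megengine/resnet18.pkl",                         # saving path, expanded internally
        http_read_timeout=120,
    )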
...@@ -73,6 +73,6 @@ _max_recursion_limit_context_manager = AlternativeRecursionLimit(2 ** 31 - 1) ...@@ -73,6 +73,6 @@ _max_recursion_limit_context_manager = AlternativeRecursionLimit(2 ** 31 - 1)
def max_recursion_limit(): def max_recursion_limit():
r"""Sets recursion limit to the max possible value r"""Sets recursion limit to the max possible value.
""" """
return _max_recursion_limit_context_manager return _max_recursion_limit_context_manager
...@@ -12,13 +12,13 @@ import numpy as np ...@@ -12,13 +12,13 @@ import numpy as np
def load_tensor_binary(fobj): def load_tensor_binary(fobj):
"""load a tensor dumped by the :class:`BinaryOprIODump` plugin; the actual """Load a tensor dumped by the :class:`BinaryOprIODump` plugin; the actual
tensor value dump is implemented by ``mgb::debug::dump_tensor``. tensor value dump is implemented by ``mgb::debug::dump_tensor``.
Multiple values can be compared by ``tools/compare_binary_iodump.py``. Multiple values can be compared by ``tools/compare_binary_iodump.py``.
:param fobj: file object, or a string that contains the file name :param fobj: file object, or a string that contains the file name.
:return: tuple ``(tensor_value, tensor_name)`` :return: tuple ``(tensor_value, tensor_name)``.
""" """
if isinstance(fobj, str): if isinstance(fobj, str):
with open(fobj, "rb") as fin: with open(fobj, "rb") as fin:
......
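A reading sketch; the import path and the dump file name are assumptions for illustration:

.. code-block:: python

    from megengine.utils.plugin import load_tensor_binary    # import path is an assumption

    # placeholder dump file produced by the BinaryOprIODump plugin
    value, name = load_tensor_binary("bin-io-dump/var_0000")
    print(name, value.shape, value.dtype)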
...@@ -16,7 +16,7 @@ import numpy as np ...@@ -16,7 +16,7 @@ import numpy as np
class NonExistNum: class NonExistNum:
"""An object that behaves like a number but means a field does not exist; It is """An object that behaves like a number but means a field does not exist; It is
always greater than any real number always greater than any real number.
""" """
def __truediv__(self, _): def __truediv__(self, _):
...@@ -69,12 +69,12 @@ class OprProfRst: ...@@ -69,12 +69,12 @@ class OprProfRst:
footprint = None footprint = None
"""A mapping from ``"memory"`` or ``"computation"`` to the actual number """A mapping from ``"memory"`` or ``"computation"`` to the actual number
of corresponding operations""" of corresponding operations."""
def __init__(self, entry: dict): def __init__(self, entry: dict):
"""Opr profiling initialization, which sets up name, type and id of opr_info. """Opr profiling initialization, which sets up name, type and id of opr_info.
:param entry: profiling json exec_graph items :param entry: profiling json exec_graph items.
""" """
assert isinstance(entry, dict) assert isinstance(entry, dict)
self.opr_info = collections.OrderedDict() self.opr_info = collections.OrderedDict()
...@@ -84,7 +84,7 @@ class OprProfRst: ...@@ -84,7 +84,7 @@ class OprProfRst:
self.footprint = collections.defaultdict(NonExistNum) self.footprint = collections.defaultdict(NonExistNum)
def update_device_prof_info(self, dev_time: dict): def update_device_prof_info(self, dev_time: dict):
"""Updates device profiling info """Updates device profiling info.
:param dev_time: device time for single opr, :param dev_time: device time for single opr,
is an attribute of profiling result. is an attribute of profiling result.
...@@ -93,7 +93,7 @@ class OprProfRst: ...@@ -93,7 +93,7 @@ class OprProfRst:
self.time_dict["device"].append(copy.deepcopy(dev_time)) self.time_dict["device"].append(copy.deepcopy(dev_time))
def update_host_prof_info(self, host_time: dict): def update_host_prof_info(self, host_time: dict):
"""Updates host profiling info """Updates host profiling info.
:param host_time: host time for single opr, :param host_time: host time for single opr,
is an attribute of profiling result. is an attribute of profiling result.
...@@ -102,7 +102,7 @@ class OprProfRst: ...@@ -102,7 +102,7 @@ class OprProfRst:
self.time_dict["host"].append(copy.deepcopy(host_time)) self.time_dict["host"].append(copy.deepcopy(host_time))
def update_footprint(self, footprint: dict): def update_footprint(self, footprint: dict):
"""Updates opr footprint """Updates opr footprint.
:param footprint: footprint for single opr, :param footprint: footprint for single opr,
is an attribute of profiling result. is an attribute of profiling result.
...@@ -128,7 +128,7 @@ class Record: ...@@ -128,7 +128,7 @@ class Record:
] ]
def __init__(self, time: float, info: dict, footprint: dict): def __init__(self, time: float, info: dict, footprint: dict):
"""Initializes single record """Initializes single record.
:param time: opr running time, evaluated by applying users providing :param time: opr running time, evaluated by applying the user-provided
function to OprProfRst. function to OprProfRst.
...@@ -153,7 +153,7 @@ class Record: ...@@ -153,7 +153,7 @@ class Record:
self.opr_id = int(self.opr_id) self.opr_id = int(self.opr_id)
def get_column_by_name(self, name: str = None): def get_column_by_name(self, name: str = None):
"""extracts column value by its column name """Extracts column value by its column name.
:param name: column name, None for time. :param name: column name, None for time.
""" """
...@@ -165,7 +165,7 @@ class Record: ...@@ -165,7 +165,7 @@ class Record:
class ProfileAnalyzer: class ProfileAnalyzer:
def __init__(self, obj: dict, opr_filter: Callable = lambda opr, inp, out: True): def __init__(self, obj: dict, opr_filter: Callable = lambda opr, inp, out: True):
"""Initializes ProfileAnalyzer """Initializes ProfileAnalyzer.
:param obj: dict dumped from json str. :param obj: dict dumped from json str.
:param opr_filter: function that filter oprs. :param opr_filter: function that filters oprs.
...@@ -202,11 +202,11 @@ class ProfileAnalyzer: ...@@ -202,11 +202,11 @@ class ProfileAnalyzer:
def _aggregate( def _aggregate(
self, records: List[Record], aop: Union[str, Callable], atype: Optional[str] self, records: List[Record], aop: Union[str, Callable], atype: Optional[str]
) -> List[Record]: ) -> List[Record]:
"""Aggregate operation """Aggregate operation.
:param records: selected records :param records: selected records.
:param aop: aggregate operation, if aop is str, we would replace it :param aop: aggregate operation; if ``aop`` is a str, it is replaced
with associated numpy function wth aop name" with the associated numpy function of that name.
:param atype: the type aggregated by, None for aggregating all into single :param atype: the type aggregated by, None for aggregating all into single
record. record.
""" """
...@@ -247,10 +247,10 @@ class ProfileAnalyzer: ...@@ -247,10 +247,10 @@ class ProfileAnalyzer:
return rst return rst
def _sort(self, records: List[Record], sort_by: str) -> List[Record]: def _sort(self, records: List[Record], sort_by: str) -> List[Record]:
"""sort operation """Sort operation.
:param records: the records after aggregate operation. :param records: the records after aggregate operation.
:param sort_by: keyword for sorting the list :param sort_by: keyword for sorting the list.
""" """
if sort_by is None: if sort_by is None:
return records return records
...@@ -271,14 +271,14 @@ class ProfileAnalyzer: ...@@ -271,14 +271,14 @@ class ProfileAnalyzer:
sort_by: str = None, sort_by: str = None,
top_k: int = 0, top_k: int = 0,
) -> List[Record]: ) -> List[Record]:
"""Select operation """Select operation.
:param time_func: time_func provided by user, would apply to every :param time_func: time function provided by the user, applied to every
OprProfRst OprProfRst.
:param opr_filter: filter satisfied operatiors. :param opr_filter: filter satisfied operators.
:param aggregate: function that apply to list of records which are :param aggregate: function that applies to the list of records which are
aggregated by atype aggregated by atype.
:param aggregate_by: the type aggregated by :param aggregate_by: the type aggregated by.
:param sort_by: keyword for sorting all records. :param sort_by: keyword for sorting all records.
:param top_k: specify the maximum number of records. :param top_k: specify the maximum number of records.
:return: the records that go through select, aggregate, sort. :return: the records that go through select, aggregate, sort.
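A selection sketch combining the pieces above; the import path and the profile file name are assumptions:

.. code-block:: python

    import json
    import numpy as np
    from megengine.utils.profile_analyze import ProfileAnalyzer, TimeFuncHelper  # path assumed

    with open("profile.json") as f:                        # json dumped by the profiler
        analyzer = ProfileAnalyzer(json.load(f))

    records = analyzer.select(
        time_func=TimeFuncHelper.eval_time_func("device", "end", np.max),
        aggregate=np.sum,                                  # sum times within each group
        aggregate_by="type",                               # group records by operator type
        sort_by="time",
        top_k=5,
    )
    # records: List[Record], already aggregated and sorted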
...@@ -304,18 +304,18 @@ class TimeFuncHelper: ...@@ -304,18 +304,18 @@ class TimeFuncHelper:
@staticmethod @staticmethod
def _eval_time(prof_type, end_key, func, opr_prof): def _eval_time(prof_type, end_key, func, opr_prof):
"""Eval time """Eval time.
:type prof_type: str :type prof_type: str
:param prof_type: 'host' or 'device' :param prof_type: 'host' or 'device'.
:type end_key: str :type end_key: str
:param end_key: 'kern' or 'end' :param end_key: 'kern' or 'end'.
:type func: function :type func: function
:param func: apply to list of all ``thread`` of ``gpu`` time. :param func: apply to list of all ``thread`` of ``gpu`` time.
:type opr_prof: `class OprProfRst` :type opr_prof: `class OprProfRst`
:param opr_prof: operator profiling result :param opr_prof: operator profiling result.
:rtype: float :rtype: float
:return: time :return: time.
""" """
if prof_type not in opr_prof.time_dict: if prof_type not in opr_prof.time_dict:
...@@ -327,10 +327,10 @@ class TimeFuncHelper: ...@@ -327,10 +327,10 @@ class TimeFuncHelper:
def eval_time_func(prof_type: str, end_key: str, func: Callable) -> float: def eval_time_func(prof_type: str, end_key: str, func: Callable) -> float:
"""Eval oprerator profile time. """Eval oprerator profile time.
:param prof_type: 'host' or 'device' :param prof_type: 'host' or 'device'.
:param end_key: 'kern' or 'end' :param end_key: 'kern' or 'end'.
:param func: apply to list of all ``thread`` of ``gpu`` time. :param func: apply to list of all ``thread`` of ``gpu`` time.
:return: Eval time results :return: eval time results.
""" """
return functools.partial(TimeFuncHelper._eval_time, prof_type, end_key, func) return functools.partial(TimeFuncHelper._eval_time, prof_type, end_key, func)
...@@ -338,18 +338,18 @@ class TimeFuncHelper: ...@@ -338,18 +338,18 @@ class TimeFuncHelper:
def _min_start( def _min_start(
prof_type, end_key, func, opr_prof prof_type, end_key, func, opr_prof
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
"""Eval minimum start time """Eval minimum start time.
:type prof_type: str :type prof_type: str
:param prof_type: 'host' or 'device' :param prof_type: 'host' or 'device'.
:type end_key: str :type end_key: str
:param end_key: 'kern' or 'end' :param end_key: 'kern' or 'end'.
:type func: function :type func: function
:param func: apply to list of all ``thread`` of ``gpu`` time. :param func: apply to list of all ``thread`` of ``gpu`` time.
:type opr_prof: `class OprProfRst` :type opr_prof: `class OprProfRst`
:param opr_prof: operator profiling result :param opr_prof: operator profiling result.
:rtype: float :rtype: float
:return: time :return: time.
""" """
if prof_type not in opr_prof.time_dict: if prof_type not in opr_prof.time_dict:
return None return None
...@@ -360,12 +360,12 @@ class TimeFuncHelper: ...@@ -360,12 +360,12 @@ class TimeFuncHelper:
def min_start_func( def min_start_func(
prof_type: str, end_key: str, func: Callable prof_type: str, end_key: str, func: Callable
) -> float: # pylint: disable=unused-argument ) -> float: # pylint: disable=unused-argument
"""Eval oprerator profile min start time """Eval oprerator profile min start time.
:param prof_type: 'host' or 'device' :param prof_type: 'host' or 'device'.
:param end_key: 'kern' or 'end' :param end_key: 'kern' or 'end'.
:param func: apply to list of all ``thread`` of ``gpu`` time. :param func: apply to list of all ``thread`` of ``gpu`` time.
:return: Eval time results :return: eval time results.
""" """
return functools.partial(TimeFuncHelper._min_start, prof_type, end_key, func) return functools.partial(TimeFuncHelper._min_start, prof_type, end_key, func)
...@@ -374,15 +374,15 @@ class TimeFuncHelper: ...@@ -374,15 +374,15 @@ class TimeFuncHelper:
"""Eval maximum end time """Eval maximum end time
:type prof_type: str :type prof_type: str
:param prof_type: 'host' or 'device' :param prof_type: 'host' or 'device'.
:type end_key: str :type end_key: str
:param end_key: 'kern' or 'end' :param end_key: 'kern' or 'end'.
:type func: function :type func: function
:param func: apply to list of all ``thread`` of ``gpu`` time. :param func: apply to list of all ``thread`` of ``gpu`` time.
:type opr_prof: `class OprProfRst` :type opr_prof: `class OprProfRst`
:param opr_prof: operator profiling result :param opr_prof: operator profiling result.
:rtype: float :rtype: float
:return: time :return: time.
""" """
if prof_type not in opr_prof.time_dict: if prof_type not in opr_prof.time_dict:
return None return None
...@@ -391,11 +391,11 @@ class TimeFuncHelper: ...@@ -391,11 +391,11 @@ class TimeFuncHelper:
@staticmethod @staticmethod
def max_end_func(prof_type: str, end_key: str, func: Callable) -> float: def max_end_func(prof_type: str, end_key: str, func: Callable) -> float:
"""Eval oprerator profile max end time """Eval oprerator profile max end time.
:param prof_type: 'host' or 'device' :param prof_type: 'host' or 'device'.
:param end_key: 'kern' or 'end' :param end_key: 'kern' or 'end'.
:param func: apply to list of all ``thread`` of ``gpu`` time. :param func: apply to list of all ``thread`` of ``gpu`` time.
:return: Eval time results :return: eval time results.
""" """
return functools.partial(TimeFuncHelper._max_end, prof_type, end_key, func) return functools.partial(TimeFuncHelper._max_end, prof_type, end_key, func)
...@@ -23,7 +23,7 @@ class Profiler: ...@@ -23,7 +23,7 @@ class Profiler:
Profile graph execution in imperative mode. Profile graph execution in imperative mode.
:type path: Optional[str] :type path: Optional[str]
:param path: default path for profiler to dump :param path: default path for profiler to dump.
Examples: Examples:
......
...@@ -7,17 +7,15 @@ class TensorSanityCheck: ...@@ -7,17 +7,15 @@ class TensorSanityCheck:
Examples: Examples:
.. testcode:: .. code-block:: python
from megengine import tensor from megengine import tensor
from megengine.utils.tensor_sanity_check import TensorSanityCheck from megengine.utils.tensor_sanity_check import TensorSanityCheck
with TensorSanityCheck() as checker: with TensorSanityCheck() as checker:
a = tensor([1, 2]) a = tensor([1, 2])
b = tensor([3, 4]) b = tensor([3, 4])
c = a + b c = a + b
print(c.numpy())
.. testoutput::
[4 6]
""" """
def __init__(self): def __init__(self):
......
...@@ -11,10 +11,10 @@ import functools ...@@ -11,10 +11,10 @@ import functools
def get_ndtuple(value, *, n, allow_zero=True): def get_ndtuple(value, *, n, allow_zero=True):
r"""Converts possibly 1D tuple to nd tuple r"""Converts possibly 1D tuple to nd tuple.
:type allow_zero: bool :type allow_zero: bool
:param allow_zero: whether to allow zero tuple value""" :param allow_zero: whether to allow zero tuple value."""
if not isinstance(value, collections.abc.Iterable): if not isinstance(value, collections.abc.Iterable):
value = int(value) value = int(value)
value = tuple([value for i in range(n)]) value = tuple([value for i in range(n)])
......
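A tiny sketch of the scalar-to-tuple behavior shown above; the import path is an assumption:

.. code-block:: python

    from megengine.utils.tuple_function import get_ndtuple   # import path is an assumption

    get_ndtuple(2, n=3)        # -> (2, 2, 2): a scalar is repeated n times
    get_ndtuple(0, n=2)        # -> (0, 0), allowed because allow_zero defaults to True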