diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py index cda2306678b4568094e3ab09172584cc474fa342..55099840cbbfc229b4aae85e49ac8196cde1b189 100644 --- a/imperative/python/megengine/core/tensor/array_method.py +++ b/imperative/python/megengine/core/tensor/array_method.py @@ -65,7 +65,7 @@ def _elwise(*args, mode): def _matmul(inp1, inp2): op = builtin.MatrixMul( - transposeA=False, transposeB=False, compute_mode="DEFAULT", format="DEFAULT" + transposeA=False, transposeB=False, compute_mode="default", format="default" ) inp1, inp2 = utils.convert_inputs(inp1, inp2) (result,) = apply(op, inp1, inp2) @@ -178,7 +178,7 @@ def _reduce(mode): def f(self, axis=None, keepdims: bool = False): data = self (data,) = utils.convert_inputs(data) - if mode == "MEAN": + if mode == "mean": data = data.astype("float32") elif self.dtype == np.bool_: data = data.astype("int32") @@ -204,7 +204,7 @@ def _reduce(mode): if not keepdims: result = _remove_axis(result, axis) if self.dtype == np.bool_: - if mode in ["MIN", "MAX"]: + if mode in ["min", "max"]: result = result.astype("bool") if axis is None or self.ndim == 1: setscalar(result) @@ -479,7 +479,7 @@ class ArrayMethodMixin(abc.ABC): 10.0 """ - return _reduce("SUM")(self, axis, keepdims) + return _reduce("sum")(self, axis, keepdims) def prod(self, axis=None, keepdims: bool = False): r""" @@ -512,7 +512,7 @@ class ArrayMethodMixin(abc.ABC): 24.0 """ - return _reduce("PRODUCT")(self, axis, keepdims) + return _reduce("product")(self, axis, keepdims) def min(self, axis=None, keepdims: bool = False): r""" @@ -545,7 +545,7 @@ class ArrayMethodMixin(abc.ABC): 1.0 """ - return _reduce("MIN")(self, axis, keepdims) + return _reduce("min")(self, axis, keepdims) def max(self, axis=None, keepdims: bool = False): r""" @@ -578,7 +578,7 @@ class ArrayMethodMixin(abc.ABC): 4.0 """ - return _reduce("MAX")(self, axis, keepdims) + return _reduce("max")(self, axis, keepdims) def mean(self, axis=None, keepdims: bool = False): r""" @@ -611,4 +611,4 @@ class ArrayMethodMixin(abc.ABC): 2.5 """ - return _reduce("MEAN")(self, axis, keepdims) + return _reduce("mean")(self, axis, keepdims) diff --git a/imperative/python/megengine/functional/loss.py b/imperative/python/megengine/functional/loss.py index 7711bf203563e22f20d92d62e66b7fafc385e06c..e2a1484508ad52d5d9bc4ec16a7768bb3a4e4208 100644 --- a/imperative/python/megengine/functional/loss.py +++ b/imperative/python/megengine/functional/loss.py @@ -267,6 +267,7 @@ def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor: 1.5 """ + norm = norm.upper() assert norm in ["L1", "L2"], "norm must be L1 or L2" # Converts binary labels to -1/1 labels. 
loss = relu(1.0 - pred * label) diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index 637434356200e0c677c507b454e4ae517726741a..cf6ecebf66f2763e8de556d2e94c0b21e9b11cb8 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -604,9 +604,9 @@ def argsort(inp: Tensor, descending: bool = False) -> Tensor: """ assert len(inp.shape) <= 2, "Input should be 1d or 2d" if descending: - order = "DESCENDING" + order = "descending" else: - order = "ASCENDING" + order = "ascending" op = builtin.Argsort(order=order) if len(inp.shape) == 1: @@ -646,9 +646,9 @@ def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]: """ assert len(inp.shape) <= 2, "Input should be 1d or 2d" if descending: - order = "DESCENDING" + order = "descending" else: - order = "ASCENDING" + order = "ascending" op = builtin.Argsort(order=order) if len(inp.shape) == 1: @@ -699,11 +699,11 @@ def topk( inp = -inp if kth_only: - mode = "KTH_ONLY" + mode = "kth_only" elif no_sort: - mode = "VALUE_IDX_NOSORT" + mode = "value_idx_nosort" else: - mode = "VALUE_IDX_SORTED" + mode = "value_idx_sorted" op = builtin.TopK(mode=mode) if not isinstance(k, Tensor): @@ -765,8 +765,8 @@ def matmul( inp2: Tensor, transpose_a=False, transpose_b=False, - compute_mode="DEFAULT", - format="DEFAULT", + compute_mode="default", + format="default", ) -> Tensor: """ Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``. @@ -776,7 +776,9 @@ def matmul( - Both 1-D tensor, simply forward to ``dot``. - Both 2-D tensor, normal matrix multiplication. - If one input tensor is 1-D, matrix vector multiplication. - - If at least one tensor are 3-dimensional or >3-dimensional, the other tensor should have dim >= 2, the batched matrix-matrix is returned, and the tensor with smaller dimension will be broadcasted. For example: + - If at least one tensor are 3-dimensional or >3-dimensional, the other tensor should have dim >= 2, + the batched matrix-matrix is returned, and the tensor with smaller dimension will be broadcasted. + For example: - inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)` - inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)` diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index ea3b0ae3e122b6d3b4530c4c69a175cb4d7c08dd..4d9e27a6b07ab4af8752b73a38b213940d1e9d54 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -52,6 +52,8 @@ __all__ = [ "deformable_psroi_pooling", "dropout", "embedding", + "hsigmoid", + "hswish", "indexing_one_hot", "leaky_relu", "linear", @@ -62,17 +64,14 @@ __all__ = [ "max_pool2d", "one_hot", "prelu", - "softmax", - "softplus", - "sync_batch_norm", - "conv1d", - "sigmoid", - "hsigmoid", "relu", "relu6", - "hswish", - "resize", "remap", + "resize", + "sigmoid", + "softmax", + "softplus", + "sync_batch_norm", "warp_affine", "warp_perspective", ] @@ -106,6 +105,83 @@ def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor return ret +def conv1d( + inp: Tensor, + weight: Tensor, + bias: Optional[Tensor] = None, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + conv_mode="cross_correlation", + compute_mode="default", +) -> Tensor: + """1D convolution operation. + + Refer to :class:`~.Conv1d` for more information. 
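A minimal usage sketch of the conv1d operation added above (shapes are illustrative; assumes the 3-dimensional NCW input/weight layout implied by the asserts in this patch, and that ``conv1d`` is importable from ``megengine.functional`` as the patch intends):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    # hypothetical sizes: batch=2, in_channels=3, length=8; out_channels=4, kernel_size=3
    inp = tensor(np.random.random((2, 3, 8)).astype(np.float32))
    weight = tensor(np.random.random((4, 3, 3)).astype(np.float32))
    out = F.conv1d(inp, weight, stride=1, padding=0, conv_mode="cross_correlation")
    print(out.numpy().shape)  # expected: (2, 4, 6)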
+ + :param inp: The feature map of the convolution operation + :param weight: The convolution kernel + :param bias: The bias added to the result of convolution (if given) + :param stride: Stride of the 1D convolution operation. Default: 1 + :param padding: Size of the paddings added to the input on both sides of its + spatial dimensions. Only zero-padding is supported. Default: 0 + :param dilation: Dilation of the 1D convolution operation. Default: 1 + :param groups: number of groups to divide input and output channels into, + so as to perform a "grouped convolution". When ``groups`` is not 1, + ``in_channels`` and ``out_channels`` must be divisible by ``groups``, + and the shape of weight should be ``(groups, out_channel // groups, + in_channels // groups, height, width)``. + :type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode` + :param conv_mode: Supports 'cross_correlation'. Default: + 'cross_correlation'. + :type compute_mode: string or + :class:`mgb.opr_param_defs.Convolution.ComputeMode` + :param compute_mode: When set to 'default', no special requirements will be + placed on the precision of intermediate results. When set to 'float32', + float32 would be used for accumulator and intermediate result, but only + effective when input and output are of float16 dtype. + + """ + assert ( + conv_mode.lower() == "cross_correlation" + or conv_mode.name == "CROSS_CORRELATION" + ) + assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT" + assert inp.ndim == 3, "the input dimension of conv1d should be 3" + assert weight.ndim == 3, "the weight dimension of conv1d should be 3" + + inp = expand_dims(inp, 3) + weight = expand_dims(weight, 3) + if bias is not None: + assert bias.ndim == 3, "the bias dimension of conv1d should be 3" + bias = expand_dims(bias, 3) + + stride_h = stride + pad_h = padding + dilate_h = dilation + + sparse_type = "dense" if groups == 1 else "group" + op = builtin.Convolution( + stride_h=stride_h, + stride_w=1, + pad_h=pad_h, + pad_w=0, + dilate_h=dilate_h, + dilate_w=1, + strategy=get_execution_strategy(), + mode=conv_mode, + compute_mode=compute_mode, + sparse=sparse_type, + ) + inp, weight = utils.convert_inputs(inp, weight) + (output,) = apply(op, inp, weight) + if bias is not None: + output += bias + output = squeeze(output, 3) + return output + + def conv2d( inp: Tensor, weight: Tensor, @@ -114,8 +190,8 @@ def conv2d( padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, - conv_mode="CROSS_CORRELATION", - compute_mode="DEFAULT", + conv_mode="cross_correlation", + compute_mode="default", ) -> Tensor: """ 2D convolution operation. @@ -135,24 +211,27 @@ def conv2d( and the shape of weight should be `(groups, out_channel // groups, in_channels // groups, height, width)`. :type conv_mode: string or :class:`Convolution.Mode` - :param conv_mode: supports "CROSS_CORRELATION". Default: - "CROSS_CORRELATION" + :param conv_mode: supports "cross_correlation". Default: + "cross_correlation" :type compute_mode: string or :class:`Convolution.ComputeMode` - :param compute_mode: when set to "DEFAULT", no special requirements will be - placed on the precision of intermediate results. When set to "FLOAT32", - "Float32" would be used for accumulator and intermediate result, but only - effective when input and output are of Float16 dtype. + :param compute_mode: when set to "default", no special requirements will be + placed on the precision of intermediate results. 
When set to "float32", + "float32" would be used for accumulator and intermediate result, but only + effective when input and output are of float16 dtype. :return: output tensor. """ - assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION" - assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT" + assert ( + conv_mode.lower() == "cross_correlation" + or conv_mode.name == "CROSS_CORRELATION" + ) + assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT" stride_h, stride_w = expand_hw(stride) pad_h, pad_w = expand_hw(padding) dilate_h, dilate_w = expand_hw(dilation) - sparse_type = "DENSE" if groups == 1 else "GROUP" + sparse_type = "dense" if groups == 1 else "group" op = builtin.Convolution( stride_h=stride_h, stride_w=stride_w, @@ -180,7 +259,7 @@ def conv3d( padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, - conv_mode: str = "CROSS_CORRELATION", + conv_mode: str = "cross_correlation", ) -> Tensor: """ 3D convolution operation. @@ -194,15 +273,16 @@ def conv3d( :param padding: size of the paddings added to the input on both sides of its spatial dimensions. Only zero-padding is supported. Default: 0 :param dilation: dilation of the 3D convolution operation. Default: 1 - :param groups: number of groups into which the input and output channels are divided, so as to perform a ``grouped convolution``. When ``groups`` is not 1, + :param groups: number of groups into which the input and output channels are divided, + so as to perform a ``grouped convolution``. When ``groups`` is not 1, ``in_channels`` and ``out_channels`` must be divisible by ``groups``, and the shape of weight should be `(groups, out_channel // groups, in_channels // groups, t, height, width)`. - :param conv_mode: supports "CROSS_CORRELATION". Default: - "CROSS_CORRELATION" + :param conv_mode: supports "cross_correlation". Default: + "cross_correlation" :return: output tensor. """ - assert conv_mode == "CROSS_CORRELATION" + assert conv_mode.lower() == "cross_correlation" D, H, W = 0, 1, 2 @@ -210,7 +290,7 @@ def conv3d( stride = _triple_nonzero(stride) dilate = _triple_nonzero(dilation) - sparse_type = "DENSE" if groups == 1 else "GROUP" + sparse_type = "dense" if groups == 1 else "group" op = builtin.Convolution3D( pad_d=pad[D], pad_h=pad[H], @@ -240,8 +320,8 @@ def conv_transpose2d( padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, - conv_mode="CROSS_CORRELATION", - compute_mode="DEFAULT", + conv_mode="cross_correlation", + compute_mode="default", ) -> Tensor: """ 2D transposed convolution operation. @@ -261,18 +341,21 @@ def conv_transpose2d( and the shape of weight should be `(groups, out_channel // groups, in_channels // groups, height, width)`. Default: 1 :type conv_mode: string or :class:`Convolution.Mode` - :param conv_mode: supports "CROSS_CORRELATION". Default: - "CROSS_CORRELATION" + :param conv_mode: supports "cross_correlation". Default: + "cross_correlation" :type compute_mode: string or :class:`Convolution.ComputeMode` - :param compute_mode: when set to "DEFAULT", no special requirements will be - placed on the precision of intermediate results. When set to "FLOAT32", - "Float32" would be used for accumulator and intermediate result, but only - effective when input and output are of Float16 dtype. + :param compute_mode: when set to "default", no special requirements will be + placed on the precision of intermediate results. 
When set to "float32", + "float32" would be used for accumulator and intermediate result, but only + effective when input and output are of float16 dtype. :return: output tensor. """ - assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION" - assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT" + assert ( + conv_mode.lower() == "cross_correlation" + or conv_mode.name == "CROSS_CORRELATION" + ) + assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT" if groups != 1: raise NotImplementedError("group transposed conv2d is not supported yet.") @@ -307,8 +390,8 @@ def deformable_conv2d( padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, - conv_mode="CROSS_CORRELATION", - compute_mode="DEFAULT", + conv_mode="cross_correlation", + compute_mode="default", ) -> Tensor: """ Deformable Convolution. @@ -328,24 +411,27 @@ def deformable_conv2d( and the shape of weight should be `(groups, out_channel // groups, in_channels // groups, height, width)`. Default: 1 :type conv_mode: string or :class:`Convolution.Mode` - :param conv_mode: supports "CROSS_CORRELATION". Default: - "CROSS_CORRELATION" + :param conv_mode: supports "cross_correlation". Default: + "cross_correlation" :type compute_mode: string or :class:`Convolution.ComputeMode` - :param compute_mode: when set to "DEFAULT", no special requirements will be - placed on the precision of intermediate results. When set to "FLOAT32", - "Float32" would be used for accumulator and intermediate result, but only - effective when input and output are of Float16 dtype. + :param compute_mode: when set to "default", no special requirements will be + placed on the precision of intermediate results. When set to "float32", + "float32" would be used for accumulator and intermediate result, but only + effective when input and output are of float16 dtype. :return: output tensor. 
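A minimal sketch of calling the functional convolutions with the lowercased mode strings documented above (shapes and channel counts are illustrative):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    x = tensor(np.random.random((1, 3, 8, 8)).astype(np.float32))
    w = tensor(np.random.random((8, 3, 3, 3)).astype(np.float32))
    # "cross_correlation" and "default" are now the documented (and default) spellings
    y = F.conv2d(x, w, conv_mode="cross_correlation", compute_mode="default")
    print(y.numpy().shape)  # expected: (1, 8, 6, 6)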
""" - assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION" - assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT" + assert ( + conv_mode.lower() == "cross_correlation" + or conv_mode.name == "CROSS_CORRELATION" + ) + assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT" stride_h, stride_w = expand_hw(stride) pad_h, pad_w = expand_hw(padding) dilate_h, dilate_w = expand_hw(dilation) - sparse_type = "DENSE" if groups == 1 else "GROUP" + sparse_type = "dense" if groups == 1 else "group" op = builtin.DeformableConv( stride_h=stride_h, stride_w=stride_w, @@ -372,10 +458,13 @@ def local_conv2d( stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, - conv_mode="CROSS_CORRELATION", + conv_mode="cross_correlation", ): """Applies spatial 2D convolution over an groupped channeled image with untied kernels.""" - assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION" + assert ( + conv_mode.lower() == "cross_correlation" + or conv_mode.name == "CROSS_CORRELATION" + ) stride_h, stride_w = expand_hw(stride) pad_h, pad_w = expand_hw(padding) @@ -389,8 +478,8 @@ def local_conv2d( dilate_h=dilate_h, dilate_w=dilate_w, mode=conv_mode, - compute_mode="DEFAULT", - sparse="DENSE", + compute_mode="default", + sparse="dense", ) inp, weight = utils.convert_inputs(inp, weight) (output,) = apply(op, inp, weight) @@ -430,7 +519,7 @@ def max_pool2d( stride_w=stride_w, pad_h=padding_h, pad_w=padding_w, - mode="MAX", + mode="max", ) (output,) = apply(op, inp) return output @@ -441,7 +530,7 @@ def avg_pool2d( kernel_size: Union[int, Tuple[int, int]], stride: Optional[Union[int, Tuple[int, int]]] = None, padding: Union[int, Tuple[int, int]] = 0, - mode: str = "AVERAGE_COUNT_EXCLUDE_PADDING", + mode: str = "average_count_exclude_padding", ) -> Tensor: """ Applies 2D average pooling over an input tensor. @@ -453,7 +542,8 @@ def avg_pool2d( :param stride: stride of the window. If not provided, its value is set to ``kernel_size``. Default: None :param padding: implicit zero padding added on both sides. Default: 0 - :param mode: whether to count padding values. Default: "AVERAGE_COUNT_EXCLUDE_PADDING" + :param mode: whether to count padding values, set to "average" will do counting. + Default: "average_count_exclude_padding" :return: output tensor. """ if stride is None: @@ -490,7 +580,7 @@ def adaptive_max_pool2d( if isinstance(oshp, int): oshp = (oshp, oshp) - op = builtin.AdaptivePooling(mode="MAX", format="NCHW",) + op = builtin.AdaptivePooling(mode="max", format="NCHW",) oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device) (output,) = apply(op, inp, oshp) return output @@ -511,7 +601,7 @@ def adaptive_avg_pool2d( if isinstance(oshp, int): oshp = (oshp, oshp) - op = builtin.AdaptivePooling(mode="AVERAGE", format="NCHW",) + op = builtin.AdaptivePooling(mode="average", format="NCHW",) oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device) (output,) = apply(op, inp, oshp) return output @@ -556,6 +646,53 @@ def deformable_psroi_pooling( return output +def hswish(x): + """ + Element-wise `x * relu6(x + 3) / 6`. + + :param x: input tensor. + :return: computed tensor. + + Example: + + .. testcode:: + + import numpy as np + from megengine import tensor + import megengine.functional as F + + x = tensor(np.arange(5).astype(np.float32)) + out = F.hswish(x) + print(out.numpy().round(decimals=4)) + + .. testoutput:: + + [0. 0.6667 1.6667 3. 4. 
] + + """ + return _elwise(x, mode=Elemwise.Mode.H_SWISH) + + +def sigmoid(x): + """Element-wise `1 / ( 1 + exp( -x ) )`.""" + return _elwise(x, mode=Elemwise.Mode.SIGMOID) + + +def hsigmoid(x): + """Element-wise `relu6(x + 3) / 6`.""" + return relu6(x + 3) / 6 + + +def relu(x): + """Element-wise `max(x, 0)`.""" + return _elwise(x, mode=Elemwise.Mode.RELU) + + +def relu6(x): + """Element-wise `min(max(x, 0), 6)`.""" + return minimum(maximum(x, 0), 6) + + def prelu(inp: Tensor, weight: Tensor) -> Tensor: r""" Applies the element-wise PReLU function. @@ -872,14 +1009,14 @@ def batch_norm( if not training: op = builtin.BatchNorm( - fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="DIM_1C11" + fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="dim_1c11" ) ret = apply(op, inp, weight, bias, running_mean, running_var)[-1] return ret else: op = builtin.BatchNorm( - avg_factor=1 - momentum, epsilon=eps, param_dim="DIM_1C11" + avg_factor=1 - momentum, epsilon=eps, param_dim="dim_1c11" ) if has_mean or has_var: running_mean = make_full_if_none(running_mean, 0) @@ -915,7 +1052,7 @@ def sync_batch_norm( training: bool = False, momentum: Union[float, Tensor] = 0.9, eps: float = 1e-5, - eps_mode="ADDITIVE", + eps_mode="additive", group=WORLD, ) -> Tensor: r""" @@ -939,7 +1076,9 @@ def sync_batch_norm( Default: 1e-5 :return: output tensor. """ - assert eps_mode in {"MAX", "ADDITIVE"}, "unknown eps_mode: {}".format(eps_mode) + assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format( + eps_mode + ) _channels = inp.shape[1] _ndim = inp.ndim _device = inp.device @@ -979,7 +1118,7 @@ def sync_batch_norm( channel_mean = running_mean.reshape(*_param_shape) invsqrt_channel_variance = ( - maximum(channel_variance, eps) if eps_mode == "MAX" else channel_variance + eps + maximum(channel_variance, eps) if eps_mode == "max" else channel_variance + eps ) ** -0.5 if weight is not None: @@ -1019,13 +1158,16 @@ def sync_batch_norm( return outvar -def one_hot(inp: Tensor, num_classes: int) -> Tensor: - r""" - Performs one-hot encoding for the input tensor. +def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor: + """ + Returns a new tensor where each of the elements are randomly set to zero + with probability P = ``drop_prob``. Optionally rescale the output tensor if ``training`` is True. :param inp: input tensor. - :param num_classes: number of classes denotes the last dimension of the output tensor. - :return: output tensor. + :param drop_prob: probability to drop (set to zero) a single element. + :param training: the default behavior of ``dropout`` during training is to rescale the output, + then it can be replaced by an :class:`~.Identity` during inference. Default: True + :return: the output tensor Examples: @@ -1035,51 +1177,33 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor: from megengine import tensor import megengine.functional as F - x = tensor(np.arange(1, 4, dtype=np.int32)) - out = F.one_hot(x, num_classes=4) + x = tensor(np.ones(10, dtype=np.float32)) + out = F.dropout(x, 1./3.) print(out.numpy()) Outputs: .. testoutput:: + :options: +SKIP - [[0 1 0 0] - [0 0 1 0] - [0 0 0 1]] - - """ - zeros_tensor = zeros(list(inp.shape) + [num_classes], inp.dtype, inp.device) - ones_tensor = ones(list(inp.shape) + [1], inp.dtype, inp.device) - - op = builtin.IndexingSetOneHot(axis=inp.ndim) - (result,) = apply(op, zeros_tensor, inp, ones_tensor) - return result - + [1.5 1.5 0. 
1.5 1.5 1.5 1.5 1.5 1.5 1.5] -def matmul( - inp1: Tensor, - inp2: Tensor, - transpose_a=False, - transpose_b=False, - compute_mode="DEFAULT", - format="DEFAULT", -) -> Tensor: """ - Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``. + assert 0 <= drop_prob < 1 + rv = uniform(size=inp.shape) + mask = rv > drop_prob + inp *= mask.astype(inp.dtype) + if training: + inp *= 1 / (1 - drop_prob) + return inp - With different inputs dim, this function behaves differently: - - Both 1-D tensor, simply forward to ``dot``. - - Both 2-D tensor, normal matrix multiplication. - - If one input tensor is 1-D, matrix vector multiplication. - - If at least one tensor are 3-dimensional or >3-dimensional, the other tensor should have dim >= 2, the batched matrix-matrix is returned, and the tensor with smaller dimension will - be broadcasted. For example: - - inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)` - - inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)` - - inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)` +def one_hot(inp: Tensor, num_classes: int) -> Tensor: + r""" + Performs one-hot encoding for the input tensor. - :param inp1: first matrix to be multiplied. - :param inp2: second matrix to be multiplied. + :param inp: input tensor. + :param num_classes: number of classes denotes the last dimension of the output tensor. :return: output tensor. Examples: @@ -1090,182 +1214,27 @@ def matmul( from megengine import tensor import megengine.functional as F - data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) - data2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2)) - out = F.matmul(data1, data2) + x = tensor(np.arange(1, 4, dtype=np.int32)) + out = F.one_hot(x, num_classes=4) print(out.numpy()) Outputs: .. testoutput:: - [[10. 13.] - [28. 
40.]] - - """ - remove_row, remove_col = False, False - inp1, inp2 = utils.convert_inputs(inp1, inp2) - - dim1, dim2 = inp1.ndim, inp2.ndim - # handle dim=1 cases, dot and matrix-vector multiplication - if dim1 == 1 and dim2 == 1: - return dot(inp1, inp2) - # the underlying matmul op requires input dims to be at least 2 - if dim1 == 1: - inp1 = expand_dims(inp1, 0) - dim1 = 2 - remove_row = True - if dim2 == 1: - inp2 = expand_dims(inp2, 1) - dim2 = 2 - remove_col = True - - batch_shape = None - shape1 = inp1.shape - shape2 = inp2.shape - - maxdim = dim1 if dim1 > dim2 else dim2 - if dim1 >= 3 or dim2 >= 3: - if use_symbolic_shape(): - if dim1 > dim2: - shape2 = concat([shape1[:-2], shape2[-2:]]) - inp2 = broadcast_to(inp2, shape2) - if dim1 < dim2: - shape1 = concat([shape2[:-2], shape1[-2:]]) - inp1 = broadcast_to(inp1, shape1) - if maxdim > 3: - batch_shape = shape1[:-2] - # compress inputs to 3d - (inp1,) = apply( - builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]]) - ) - (inp2,) = apply( - builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]]) - ) - else: - if dim1 > dim2: - shape2 = shape1[:-2] + shape2[-2:] - inp2 = broadcast_to(inp2, shape2) - if dim1 < dim2: - shape1 = shape2[:-2] + shape1[-2:] - inp1 = broadcast_to(inp1, shape1) - if maxdim > 3: - batch_shape = shape1[:-2] - # compress inputs to 3d - inp1 = inp1.reshape((-1, shape1[-2], shape1[-1])) - inp2 = inp2.reshape((-1, shape2[-2], shape2[-1])) - - op = builtin.BatchedMatrixMul( - transposeA=transpose_a, - transposeB=transpose_b, - compute_mode=compute_mode, - format=format, - strategy=get_execution_strategy(), - ) - else: - op = builtin.MatrixMul( - transposeA=transpose_a, - transposeB=transpose_b, - compute_mode=compute_mode, - format=format, - strategy=get_execution_strategy(), - ) - - (result,) = apply(op, inp1, inp2) - if maxdim > 3: - if use_symbolic_shape(): - (result,) = apply( - builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]]) - ) - else: - result = result.reshape(batch_shape + result.shape[-2:]) - if remove_row: - result = squeeze(result, axis=-2) - if remove_col: - result = squeeze(result, axis=-1) - return result - + [[0 1 0 0] + [0 0 1 0] + [0 0 0 1]] -def dot(inp1: Tensor, inp2: Tensor) -> Tensor: """ - Computes dot-product of two vectors ``inp1`` and ``inp2``. - inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted. - Refer to :func:`~.matmul` for more general usage. - - :param inp1: first vector. - :param inp2: second vector. - :return: output value. - - Examples: - - .. testcode:: - - import numpy as np - from megengine import tensor - import megengine.functional as F - - data1 = tensor(np.arange(0, 6, dtype=np.float32)) - data2 = tensor(np.arange(0, 6, dtype=np.float32)) - out = F.dot(data1, data2) - print(out.numpy()) - - Outputs: - - .. testoutput:: - - 55. 
+ zeros_tensor = zeros(list(inp.shape) + [num_classes], inp.dtype, inp.device) + ones_tensor = ones(list(inp.shape) + [1], inp.dtype, inp.device) - """ - op = builtin.Dot() - inp1, inp2 = utils.convert_inputs(inp1, inp2) - assert ( - inp1.ndim <= 1 and inp2.ndim <= 1 - ), "Input tensors for dot must be 1-dimensional or scalar" - (result,) = apply(op, inp1, inp2) - setscalar(result) + op = builtin.IndexingSetOneHot(axis=inp.ndim) + (result,) = apply(op, zeros_tensor, inp, ones_tensor) return result -def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor: - """ - Returns a new tensor where each of the elements are randomly set to zero - with probability P = ``drop_prob``. Optionally rescale the output tensor if ``training`` is True. - - :param inp: input tensor. - :param drop_prob: probability to drop (set to zero) a single element. - :param training: the default behavior of ``dropout`` during training is to rescale the output, - then it can be replaced by an :class:`~.Identity` during inference. Default: True - :return: the output tensor - - Examples: - - .. testcode:: - - import numpy as np - from megengine import tensor - import megengine.functional as F - - x = tensor(np.ones(10, dtype=np.float32)) - out = F.dropout(x, 1./3.) - print(out.numpy()) - - Outputs: - - .. testoutput:: - :options: +SKIP - - [1.5 1.5 0. 1.5 1.5 1.5 1.5 1.5 1.5 1.5] - - """ - assert 0 <= drop_prob < 1 - rv = uniform(size=inp.shape) - mask = rv > drop_prob - inp *= mask.astype(inp.dtype) - if training: - inp *= 1 / (1 - drop_prob) - return inp - - def embedding( inp: Tensor, weight: Tensor, @@ -1334,128 +1303,6 @@ def indexing_one_hot( return result -def conv1d( - inp: Tensor, - weight: Tensor, - bias: Optional[Tensor] = None, - stride: int = 1, - padding: int = 0, - dilation: int = 1, - groups: int = 1, - conv_mode="CROSS_CORRELATION", - compute_mode="DEFAULT", -) -> Tensor: - """1D convolution operation. - - Refer to :class:`~.Conv1d` for more information. - - :param inp: The feature map of the convolution operation - :param weight: The convolution kernel - :param bias: The bias added to the result of convolution (if given) - :param stride: Stride of the 1D convolution operation. Default: 1 - :param padding: Size of the paddings added to the input on both sides of its - spatial dimensions. Only zero-padding is supported. Default: 0 - :param dilation: Dilation of the 1D convolution operation. Default: 1 - :param groups: number of groups to divide input and output channels into, - so as to perform a "grouped convolution". When ``groups`` is not 1, - ``in_channels`` and ``out_channels`` must be divisible by ``groups``, - and the shape of weight should be ``(groups, out_channel // groups, - in_channels // groups, height, width)``. - :type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode` - :param conv_mode: Supports 'CROSS_CORRELATION'. Default: - 'CROSS_CORRELATION'. - :type compute_mode: string or - :class:`mgb.opr_param_defs.Convolution.ComputeMode` - :param compute_mode: When set to 'DEFAULT', no special requirements will be - placed on the precision of intermediate results. When set to 'FLOAT32', - Float32 would be used for accumulator and intermediate result, but only - effective when input and output are of Float16 dtype. 
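The batched broadcasting behaviour described in the matmul docstring above can be sketched as follows (shapes are illustrative):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    a = tensor(np.random.random((2, 4, 5)).astype(np.float32))  # (n, k, m)
    b = tensor(np.random.random((5, 3)).astype(np.float32))     # (m, p)
    # the 2-D operand is broadcast against the leading batch dimension of the 3-D operand
    out = F.matmul(a, b)
    print(out.numpy().shape)  # expected: (2, 4, 3)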
- - """ - - assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION" - assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT" - assert inp.ndim == 3, "the input dimension of conv1d should be 3" - assert weight.ndim == 3, "the weight dimension of conv1d should be 3" - - inp = expand_dims(inp, 3) - weight = expand_dims(weight, 3) - if bias is not None: - assert bias.ndim == 3, "the bias dimension of conv1d should be 3" - bias = expand_dims(bias, 3) - - stride_h = stride - pad_h = padding - dilate_h = dilation - - sparse_type = "DENSE" if groups == 1 else "GROUP" - op = builtin.Convolution( - stride_h=stride_h, - stride_w=1, - pad_h=pad_h, - pad_w=0, - dilate_h=dilate_h, - dilate_w=1, - strategy=get_execution_strategy(), - mode=conv_mode, - compute_mode=compute_mode, - sparse=sparse_type, - ) - inp, weight = utils.convert_inputs(inp, weight) - (output,) = apply(op, inp, weight) - if bias is not None: - output += bias - output = squeeze(output, 3) - return output - - -def hswish(x): - """ - Element-wise `x * relu6(x + 3) / 6`. - - :param x: input tensor. - :return: computed tensor. - - Example: - - .. testcode:: - - import numpy as np - from megengine import tensor - import megengine.functional as F - - x = tensor(np.arange(5).astype(np.float32)) - out = F.hswish(x) - print(out.numpy().round(decimals=4)) - - .. testoutput:: - - [0. 0.6667 1.6667 3. 4. ] - - """ - return _elwise(x, mode=Elemwise.Mode.H_SWISH) - - -def sigmoid(x): - """Element-wise `1 / ( 1 + exp( -x ) )`.""" - return _elwise(x, mode=Elemwise.Mode.SIGMOID) - - -def hsigmoid(x): - """Element-wise `relu6(x + 3) / 6`.""" - return relu6(x + 3) / 6 - - -def relu(x): - """Element-wise `max(x, 0)`.""" - return _elwise(x, mode=Elemwise.Mode.RELU) - - -def relu6(x): - """Element-wise `min(max(x, 0), 6)`.""" - return minimum(maximum(x, 0), 6) - - interpolate = deprecated_func("1.3", "megengine.functional.vision", "interpolate", True) roi_pooling = deprecated_func("1.3", "megengine.functional.vision", "roi_pooling", True) roi_align = deprecated_func("1.3", "megengine.functional.vision", "roi_align", True) diff --git a/imperative/python/megengine/functional/quantized.py b/imperative/python/megengine/functional/quantized.py index 53a00d626987d6cdf1eebd82e0860da800b90564..17a45d04de214f2b092a44f4f259dac25ed28ceb 100644 --- a/imperative/python/megengine/functional/quantized.py +++ b/imperative/python/megengine/functional/quantized.py @@ -24,9 +24,9 @@ def conv_bias_activation( padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, - nonlinear_mode="IDENTITY", - conv_mode="CROSS_CORRELATION", - compute_mode="DEFAULT", + nonlinear_mode="identity", + conv_mode="cross_correlation", + compute_mode="default", ) -> Tensor: """ Convolution bias with activation operation, only for inference. @@ -35,27 +35,30 @@ def conv_bias_activation( :param weight: convolution kernel. :param bias: bias added to the result of convolution :param stride: stride of the 2D convolution operation. Default: 1 - :param padding: size of the paddings added to the input on both sides of its spatial dimensions. Only zero-padding is supported. Default: 0 + :param padding: size of the paddings added to the input on both sides + of its spatial dimensions. Only zero-padding is supported. Default: 0 :param dilation: dilation of the 2D convolution operation. Default: 1 - :param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". 
When ``groups`` is not 1, + :param groups: number of groups into which the input and output channels are divided, + so as to perform a "grouped convolution". When ``groups`` is not 1, ``in_channels`` and ``out_channels`` must be divisible by ``groups``, and the shape of weight should be `(groups, out_channel // groups, in_channels // groups, height, width)`. :type conv_mode: string or :class:`Convolution.Mode`. - :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: - 'CROSS_CORRELATION' + :param conv_mode: supports 'cross_correlation' or 'convolution'. Default: + 'cross_correlation' :param dtype: support for ``np.dtype``, Default: np.int8 :type compute_mode: string or :class:`Convolution.ComputeMode`. - :param compute_mode: when set to "DEFAULT", no special requirements will be - placed on the precision of intermediate results. When set to "FLOAT32", - "Float32" would be used for accumulator and intermediate result, but only effective when input and output are of Float16 dtype. + :param compute_mode: when set to "default", no special requirements will be + placed on the precision of intermediate results. When set to "float32", + "float32" would be used for accumulator and intermediate result, + but only effective when input and output are of float16 dtype. """ ph, pw = _pair(padding) sh, sw = _pair_nonzero(stride) dh, dw = _pair_nonzero(dilation) - sparse_type = "DENSE" if groups == 1 else "GROUP" + sparse_type = "dense" if groups == 1 else "group" op = builtin.ConvBias( stride_h=sh, stride_w=sw, @@ -84,9 +87,9 @@ def batch_conv_bias_activation( padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, - nonlinear_mode="IDENTITY", - conv_mode="CROSS_CORRELATION", - compute_mode="DEFAULT", + nonlinear_mode="identity", + conv_mode="cross_correlation", + compute_mode="default", ) -> Tensor: """ Batch convolution bias with activation operation, only for inference. @@ -95,27 +98,30 @@ def batch_conv_bias_activation( :param weight: convolution kernel in batched way. :param bias: bias added to the result of convolution :param stride: stride of the 2D convolution operation. Default: 1 - :param padding: size of the paddings added to the input on both sides of its spatial dimensions. Only zero-padding is supported. Default: 0 + :param padding: size of the paddings added to the input on both sides + of its spatial dimensions. Only zero-padding is supported. Default: 0 :param dilation: dilation of the 2D convolution operation. Default: 1 - :param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1, + :param groups: number of groups into which the input and output channels are divided, + so as to perform a "grouped convolution". When ``groups`` is not 1, ``in_channels`` and ``out_channels`` must be divisible by ``groups``, and the shape of weight should be `(groups, out_channel // groups, in_channels // groups, height, width)`. :type conv_mode: string or :class:`Convolution.Mode`. - :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: - 'CROSS_CORRELATION' + :param conv_mode: supports 'cross_correlation' or 'convolution'. Default: + 'cross_correlation' :param dtype: support for ``np.dtype``, Default: np.int8 :type compute_mode: string or :class:`Convolution.ComputeMode`. - :param compute_mode: when set to "DEFAULT", no special requirements will be - placed on the precision of intermediate results. 
When set to "FLOAT32", - "Float32" would be used for accumulator and intermediate result, but only effective when input and output are of Float16 dtype. + :param compute_mode: when set to "default", no special requirements will be + placed on the precision of intermediate results. When set to "float32", + "float32" would be used for accumulator and intermediate result, + but only effective when input and output are of float16 dtype. """ ph, pw = _pair(padding) sh, sw = _pair_nonzero(stride) dh, dw = _pair_nonzero(dilation) - sparse_type = "DENSE" if groups == 1 else "GROUP" + sparse_type = "dense" if groups == 1 else "group" op = builtin.BatchConvBias( stride_h=sh, stride_w=sw, diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 0b45ec9c97711b33f394af066d64274e2bdb508d..6b50027be0a9fec738e5f5c2f2ffdd72e9eb2101 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -335,12 +335,8 @@ def split(inp, nsplits_or_sections, axis=0): y = F.split(x, 3) z = F.split(x, [6, 17], axis=1) - if os.environ.get("MEGENGINE_USE_SYMBOLIC_SHAPE"): - print([tuple(i.shape.numpy().tolist()) for i in y]) - print([tuple(i.shape.numpy().tolist()) for i in z]) - else: - print([i.shape for i in y]) - print([i.shape for i in z]) + print([i.numpy().shape for i in y]) + print([i.numpy().shape for i in z]) Outputs: diff --git a/imperative/python/megengine/functional/vision.py b/imperative/python/megengine/functional/vision.py index b2a97586930d22a9185f920a15fad2e9d1cd3a94..07d8c3aa69545afcfe0257cc0fc15d6eeb9ca5f5 100644 --- a/imperative/python/megengine/functional/vision.py +++ b/imperative/python/megengine/functional/vision.py @@ -46,6 +46,7 @@ def cvt_color(inp: Tensor, mode: str = ""): [[[[0.86555195]]]] """ + mode = mode.upper() assert mode in builtin.CvtColor.Mode.__dict__, "unspport mode for cvt_color" mode = getattr(builtin.CvtColor.Mode, mode) assert isinstance(mode, builtin.CvtColor.Mode) @@ -92,9 +93,8 @@ def roi_pooling( [[[-0.1383 -0.1383] [-0.5035 -0.5035]]] - """ - assert mode in ["max", "average"], "only max/average mode is supported" + assert mode.lower() in ["max", "average"], "only max/average mode is supported" if isinstance(output_shape, int): output_shape = (output_shape, output_shape) @@ -151,6 +151,7 @@ def roi_align( [0.1359 0.1359]]] """ + mode = mode.lower() assert mode in ["max", "average"], "only max/average mode is supported" if isinstance(output_shape, int): output_shape = (output_shape, output_shape) @@ -244,9 +245,9 @@ def nms( def remap( inp: Tensor, map_xy: Tensor, - border_mode: str = "REPLICATE", + border_mode: str = "replicate", scalar: float = 0.0, - interp_mode: str = "LINEAR", + interp_mode: str = "linear", ) -> Tensor: r""" Applies remap transformation to batched 2D images. @@ -257,11 +258,11 @@ def remap( :param inp: input image :param map_xy: (batch, oh, ow, 2) transformation matrix :param border_mode: pixel extrapolation method. - Default: "REPLICATE". Currently also support "CONSTANT", "REFLECT", - "REFLECT_101", "WRAP". + Default: "replicate". Currently also support "constant", "reflect", + "reflect_101", "wrap". :param scalar: value used in case of a constant border. Default: 0 :param interp_mode: interpolation methods. - Default: "LINEAR". Currently only support "LINEAR" mode. + Default: "linear". Currently only support "linear" mode. :return: output tensor. 
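A minimal sketch of the vision-mode handling changed above: ``cvt_color`` now upper-cases its ``mode`` argument before the lookup, so lowercase spellings are accepted (assumes ``RGB2GRAY`` is among the supported ``CvtColor`` modes and that the function is reachable as ``megengine.functional.vision.cvt_color``):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    x = tensor(np.random.random((1, 1, 1, 3)).astype(np.float32))
    # "rgb2gray" is upper-cased internally and resolves to the same CvtColor mode as "RGB2GRAY"
    y = F.vision.cvt_color(x, mode="rgb2gray")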
Examples: @@ -301,10 +302,10 @@ def warp_affine( inp: Tensor, weight: Tensor, out_shape, - border_mode="REPLICATE", + border_mode="replicate", border_val=0, format="NHWC", - imode="LINEAR", + imode="linear", ): """ Batched affine transform on 2D images. @@ -313,13 +314,13 @@ def warp_affine( :param weight: weight tensor. :param out_shape: output tensor shape. :param border_mode: pixel extrapolation method. - Default: "WRAP". Currently "CONSTANT", "REFLECT", - "REFLECT_101", "ISOLATED", "WRAP", "REPLICATE", "TRANSPARENT" are supported. + Default: "wrap". Currently "constant", "reflect", + "reflect_101", "isolated", "wrap", "replicate", "transparent" are supported. :param border_val: value used in case of a constant border. Default: 0 :param format: "NHWC" as default based on historical concerns, - "NCHW" is also supported. Default: "NCHW". - :param imode: interpolation methods. Could be "LINEAR", "NEAREST", "CUBIC", "AREA". - Default: "LINEAR". + "NCHW" is also supported. Default: "NHWC". + :param imode: interpolation methods. Could be "linear", "nearest", "cubic", "area". + Default: "linear". :return: output tensor. .. note:: @@ -340,9 +341,9 @@ def warp_perspective( inp: Tensor, M: Tensor, dsize: Union[Tuple[int, int], int, Tensor], - border_mode: str = "REPLICATE", + border_mode: str = "replicate", border_val: float = 0.0, - interp_mode: str = "LINEAR", + interp_mode: str = "linear", ) -> Tensor: r""" Applies perspective transformation to batched 2D images. @@ -359,11 +360,11 @@ def warp_perspective( :param M: `(batch, 3, 3)` transformation matrix. :param dsize: `(h, w)` size of the output image. :param border_mode: pixel extrapolation method. - Default: "REPLICATE". Currently also support "CONSTANT", "REFLECT", - "REFLECT_101", "WRAP". + Default: "replicate". Currently also support "constant", "reflect", + "reflect_101", "wrap". :param border_val: value used in case of a constant border. Default: 0 :param interp_mode: interpolation methods. - Default: "LINEAR". Currently only support "LINEAR" mode. + Default: "linear". Currently only support "linear" mode. :return: output tensor. Note: @@ -409,7 +410,7 @@ def interpolate( inp: Tensor, size: Optional[Union[int, Tuple[int, int]]] = None, scale_factor: Optional[Union[float, Tuple[float, float]]] = None, - mode: str = "BILINEAR", + mode: str = "bilinear", align_corners: Optional[bool] = None, ) -> Tensor: r""" @@ -419,9 +420,9 @@ def interpolate( :param size: size of the output tensor. Default: None :param scale_factor: scaling factor of the output tensor. Default: None :param mode: interpolation methods, acceptable values are: - "BILINEAR", "LINEAR". Default: "BILINEAR" + "bilinear", "linear". Default: "bilinear" :param align_corners: This only has an effect when `mode` - is "BILINEAR" or "LINEAR". Geometrically, we consider the pixels of the input + is "bilinear" or "linear". Geometrically, we consider the pixels of the input and output as squares rather than points. If set to ``True``, the input and output tensors are aligned by the center points of their corner pixels, preserving the values at the corner pixels. If set to ``False``, @@ -455,10 +456,10 @@ def interpolate( [3. 3.25 3.75 4. 
]]]] """ - mode = mode.upper() - if mode not in ["BILINEAR", "LINEAR"]: + mode = mode.lower() + if mode not in ["bilinear", "linear"]: raise ValueError("interpolate only support linear or bilinear mode") - if mode not in ["BILINEAR", "LINEAR"]: + if mode not in ["bilinear", "linear"]: if align_corners is not None: raise ValueError( "align_corners option can only be set in the bilinear/linear interpolating mode" @@ -471,16 +472,16 @@ def interpolate( size is not None and scale_factor is None and not align_corners - and mode == "BILINEAR" + and mode == "bilinear" and inp.ndim in [4, 5] ): # fastpath for interpolate - op = builtin.Resize(imode="LINEAR", format="NCHW") + op = builtin.Resize(imode="linear", format="NCHW") shape = astensor1d(size, inp, dtype="int32", device=inp.device) (result,) = apply(op, inp, shape) return result - if mode == "LINEAR": + if mode == "linear": inp = expand_dims(inp, 3) if inp.ndim != 4: @@ -492,14 +493,14 @@ def interpolate( if isinstance(scale_factor, (float, int)): scale_factor = float(scale_factor) - if mode == "LINEAR": + if mode == "linear": scale_factor = (scale_factor, float(1)) else: scale_factor = (scale_factor, scale_factor) else: - if mode == "LINEAR": + if mode == "linear": raise ValueError( - "under LINEAR mode, scale_factor can only be single value" + "under linear mode, scale_factor can only be single value" ) assert len(scale_factor) == 2, "shape of scale_factor must be equal to (2, )" @@ -524,8 +525,8 @@ def interpolate( if isinstance(size, int): size = (size, 1) else: - if mode == "LINEAR": - raise ValueError("under LINEAR mode, size can only be single value") + if mode == "linear": + raise ValueError("under linear mode, size can only be single value") dsize = size oh, ow = dsize[0], dsize[1] @@ -534,7 +535,7 @@ def interpolate( if align_corners: hscale = (ih - 1.0) / (oh - 1.0) wscale = 1.0 * iw / ow - if mode != "LINEAR": + if mode != "linear": wscale = (iw - 1.0) / (ow - 1.0) row0 = concat( [wscale, Tensor([0, 0], dtype="float32", device=inp.device)], axis=0 @@ -570,8 +571,8 @@ def interpolate( weight = broadcast_to(weight, (inp.shape[0], 3, 3)) weight = weight.astype("float32") - ret = warp_perspective(inp, weight, dsize, interp_mode="LINEAR") - if mode == "LINEAR": + ret = warp_perspective(inp, weight, dsize, interp_mode="linear") + if mode == "linear": ret = reshape(ret, ret.shape[0:3]) return ret diff --git a/imperative/python/megengine/module/batch_matmul_activation.py b/imperative/python/megengine/module/batch_matmul_activation.py index 4f5c8ac62272c22fdae03441a68e1b33fe8e5fbc..301f0a7215ec01d46ac83f00277c05a273b6a6cc 100644 --- a/imperative/python/megengine/module/batch_matmul_activation.py +++ b/imperative/python/megengine/module/batch_matmul_activation.py @@ -24,7 +24,7 @@ class BatchMatMulActivation(Module): in_features: int, out_features: int, bias: bool = True, - nonlinear_mode="IDENTITY", + nonlinear_mode="identity", **kwargs ): super().__init__(**kwargs) @@ -37,7 +37,7 @@ class BatchMatMulActivation(Module): if bias: b_shape = (out_features,) self.bias = Parameter(np.zeros(b_shape, dtype=np.float32)) - self.nonlinear_mode = nonlinear_mode + self.nonlinear_mode = nonlinear_mode.lower() self.reset_parameters() def _get_fanin(self): @@ -54,7 +54,7 @@ class BatchMatMulActivation(Module): res = matmul(weight, x) if self.bias is not None: res += bias - if self.nonlinear_mode == "RELU": + if self.nonlinear_mode == "relu": res = relu(res) return res diff --git a/imperative/python/megengine/module/conv.py 
b/imperative/python/megengine/module/conv.py index 4c8b99e8dbec4a8ceefc080401d5134c4ed2a270..1732f2d3e4bebf190c53da0810d2ea656956fc39 100644 --- a/imperative/python/megengine/module/conv.py +++ b/imperative/python/megengine/module/conv.py @@ -138,11 +138,11 @@ class Conv1d(_ConvNd): out_channel // groups, in_channels // groups, *kernel_size)`. :param bias: whether to add a bias onto the result of convolution. Default: True - :param conv_mode: Supports `CROSS_CORRELATION`. Default: - `CROSS_CORRELATION` - :param compute_mode: When set to "DEFAULT", no special requirements will be - placed on the precision of intermediate results. When set to "FLOAT32", - "Float32" would be used for accumulator and intermediate result, but only + :param conv_mode: Supports `cross_correlation`. Default: + `cross_correlation` + :param compute_mode: When set to "default", no special requirements will be + placed on the precision of intermediate results. When set to "float32", + "float32" would be used for accumulator and intermediate result, but only effective when input and output are of float16 dtype. Examples: @@ -176,8 +176,8 @@ class Conv1d(_ConvNd): dilation: int = 1, groups: int = 1, bias: bool = True, - conv_mode: str = "CROSS_CORRELATION", - compute_mode: str = "DEFAULT", + conv_mode: str = "cross_correlation", + compute_mode: str = "default", **kwargs ): kernel_size = kernel_size @@ -298,11 +298,11 @@ class Conv2d(_ConvNd): out_channel // groups, in_channels // groups, *kernel_size)`. :param bias: whether to add a bias onto the result of convolution. Default: True - :param conv_mode: Supports `CROSS_CORRELATION`. Default: - `CROSS_CORRELATION` - :param compute_mode: When set to "DEFAULT", no special requirements will be - placed on the precision of intermediate results. When set to "FLOAT32", - "Float32" would be used for accumulator and intermediate result, but only + :param conv_mode: Supports `cross_correlation`. Default: + `cross_correlation` + :param compute_mode: When set to "default", no special requirements will be + placed on the precision of intermediate results. When set to "float32", + "float32" would be used for accumulator and intermediate result, but only effective when input and output are of float16 dtype. Examples: @@ -336,8 +336,8 @@ class Conv2d(_ConvNd): dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, bias: bool = True, - conv_mode: str = "CROSS_CORRELATION", - compute_mode: str = "DEFAULT", + conv_mode: str = "cross_correlation", + compute_mode: str = "default", **kwargs ): kernel_size = _pair_nonzero(kernel_size) @@ -436,15 +436,16 @@ class Conv3d(_ConvNd): :param padding: size of the paddings added to the input on both sides of its spatial dimensions. Only zero-padding is supported. Default: 0 :param dilation: dilation of the 3D convolution operation. Default: 1 - :param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1, + :param groups: number of groups into which the input and output channels are divided, + so as to perform a "grouped convolution". When ``groups`` is not 1, ``in_channels`` and ``out_channels`` must be divisible by ``groups``, and there would be an extra dimension at the beginning of the weight's shape. Specifically, the shape of weight would be `(groups, out_channel // groups, in_channels // groups, *kernel_size)`. :param bias: whether to add a bias onto the result of convolution. Default: True - :param conv_mode: Supports `CROSS_CORRELATION`. 
Default: - `CROSS_CORRELATION` + :param conv_mode: Supports `cross_correlation`. Default: + `cross_correlation` Examples: @@ -477,7 +478,7 @@ class Conv3d(_ConvNd): dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, bias: bool = True, - conv_mode: str = "CROSS_CORRELATION", + conv_mode: str = "cross_correlation", ): kernel_size = _triple_nonzero(kernel_size) stride = _triple_nonzero(stride) @@ -566,11 +567,11 @@ class ConvTranspose2d(_ConvNd): out_channels // groups, in_channels // groups, *kernel_size)``. Default: 1 :param bias: wether to add a bias onto the result of convolution. Default: True - :param conv_mode: Supports `CROSS_CORRELATION`. Default: - `CROSS_CORRELATION` - :param compute_mode: When set to "DEFAULT", no special requirements will be - placed on the precision of intermediate results. When set to "FLOAT32", - "Float32" would be used for accumulator and intermediate result, but only + :param conv_mode: Supports `cross_correlation`. Default: + `cross_correlation` + :param compute_mode: When set to "default", no special requirements will be + placed on the precision of intermediate results. When set to "float32", + "float32" would be used for accumulator and intermediate result, but only effective when input and output are of float16 dtype. """ @@ -584,8 +585,8 @@ class ConvTranspose2d(_ConvNd): dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, bias: bool = True, - conv_mode: str = "CROSS_CORRELATION", - compute_mode: str = "DEFAULT", + conv_mode: str = "cross_correlation", + compute_mode: str = "default", **kwargs ): kernel_size = _pair_nonzero(kernel_size) @@ -679,7 +680,7 @@ class LocalConv2d(Conv2d): padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, - conv_mode: str = "CROSS_CORRELATION", + conv_mode: str = "cross_correlation", **kwargs ): self.input_height = input_height @@ -758,11 +759,11 @@ class DeformableConv2d(_ConvNd): out_channel // groups, in_channels // groups, *kernel_size)`. :param bias: whether to add a bias onto the result of convolution. Default: True - :param conv_mode: Supports `CROSS_CORRELATION`. Default: - `CROSS_CORRELATION` - :param compute_mode: When set to "DEFAULT", no special requirements will be - placed on the precision of intermediate results. When set to "FLOAT32", - "Float32" would be used for accumulator and intermediate result, but only + :param conv_mode: Supports `cross_correlation`. Default: + `cross_correlation` + :param compute_mode: When set to "default", no special requirements will be + placed on the precision of intermediate results. When set to "float32", + "float32" would be used for accumulator and intermediate result, but only effective when input and output are of float16 dtype. 
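A minimal sketch of constructing the convolution modules updated above; the lowercase strings are now the defaults, so passing them explicitly is optional (channel and kernel sizes are illustrative):

    import megengine.module as M

    conv = M.Conv2d(3, 8, kernel_size=3, conv_mode="cross_correlation", compute_mode="default")
    deconv = M.ConvTranspose2d(8, 3, kernel_size=3)  # uses the same lowercase defaults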
""" @@ -776,8 +777,8 @@ class DeformableConv2d(_ConvNd): dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, bias: bool = True, - conv_mode: str = "CROSS_CORRELATION", - compute_mode: str = "DEFAULT", + conv_mode: str = "cross_correlation", + compute_mode: str = "default", **kwargs ): kernel_size = _pair_nonzero(kernel_size) diff --git a/imperative/python/megengine/module/conv_bn.py b/imperative/python/megengine/module/conv_bn.py index 390ec3b11c3407ab5000da141c74223451fc64eb..d27a292ef06ead67b3133f0c8ccc74c02f2ef6b0 100644 --- a/imperative/python/megengine/module/conv_bn.py +++ b/imperative/python/megengine/module/conv_bn.py @@ -24,8 +24,8 @@ class _ConvBnActivation2d(Module): dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, bias: bool = True, - conv_mode: str = "CROSS_CORRELATION", - compute_mode: str = "DEFAULT", + conv_mode: str = "cross_correlation", + compute_mode: str = "default", eps=1e-5, momentum=0.9, affine=True, diff --git a/imperative/python/megengine/module/elemwise.py b/imperative/python/megengine/module/elemwise.py index a3879ab8d40ff6a418973e32d0f138128c4f1975..8f935227c9f75430c81f29ef4cf68653d27df486 100644 --- a/imperative/python/megengine/module/elemwise.py +++ b/imperative/python/megengine/module/elemwise.py @@ -18,58 +18,58 @@ class Elemwise(Module): :param method: the elemwise method, support the following string. It will do the normal elemwise operator for float. - * "ADD": a + b - * "FUSE_ADD_RELU": max(x+y, 0) - * "MUL": x * y - * "MIN": min(x, y) - * "MAX": max(x, y) - * "SUB": x - y - * "TRUE_DIV": x / y - * "FUSE_ADD_SIGMOID": sigmoid(x + y) - * "FUSE_ADD_TANH": tanh(x + y) - * "RELU": x > 0 ? x : 0 - * "ABS": x > 0 ? x : -x - * "SIGMOID": sigmoid(x) - * "EXP": exp(x) - * "TANH": tanh(x) - * "FUSE_MUL_ADD3": x * y + z - * "FAST_TANH": x * (27. + x * x) / (27. + 9. * x * x) - * "NEGATE": -x - * "ACOS": acos(x) - * "ASIN": asin(x) - * "CEIL": ceil(x) - * "COS": cos(x) - * "EXPM1": expm1(x) - * "FLOOR": floor(x) - * "LOG": log(x) - * "LOG1P": log1p(x) - * "SIN": sin(x) - * "ROUND": round(x) - * "ERF": erf(x) - * "ERFINV": erfinv(x) - * "ERFC": erfc(x) - * "ERFCINV": erfcinv(x) - * "ABS_GRAD": abs_grad - * "FLOOR_DIV": floor_div - * "MOD": mod - * "SIGMOID_GRAD": sigmoid_grad - * "SWITCH_GT0": switch_gt0 - * "TANH_GRAD": tanh_grad - * "LT": less - * "LEQ": leq - * "EQ": equal - * "POW": pow - * "LOG_SUM_EXP": log_sum_exp - * "FAST_TANH_GRAD": fast_tanh_grad - * "ATAN2": atan2 - * "COND_LEQ_MOV": cond_leq_mov - * "H_SWISH": h_swish - * "FUSE_ADD_H_SWISH": h_swish(x+y) - * "H_SWISH_GRAD": h_swish_grad - * "AND": bool binary: x && y - * "OR": bool binary: x || y - * "XOR": bool binary: x ^ y - * "NOT": bool unary: ~x + * "add": a + b + * "fuse_add_relu": max(x+y, 0) + * "mul": x * y + * "min": min(x, y) + * "max": max(x, y) + * "sub": x - y + * "true_div": x / y + * "fuse_add_sigmoid": sigmoid(x + y) + * "fuse_add_tanh": tanh(x + y) + * "relu": x > 0 ? x : 0 + * "abs": x > 0 ? x : -x + * "sigmoid": sigmoid(x) + * "exp": exp(x) + * "tanh": tanh(x) + * "fuse_mul_add3": x * y + z + * "fast_tanh": x * (27. + x * x) / (27. + 9. 
diff --git a/imperative/python/megengine/module/elemwise.py b/imperative/python/megengine/module/elemwise.py
index a3879ab8d40ff6a418973e32d0f138128c4f1975..8f935227c9f75430c81f29ef4cf68653d27df486 100644
--- a/imperative/python/megengine/module/elemwise.py
+++ b/imperative/python/megengine/module/elemwise.py
@@ -18,58 +18,58 @@ class Elemwise(Module):
     :param method: the elemwise method, support the following string.
         It will do the normal elemwise operator for float.

-    * "ADD": a + b
-    * "FUSE_ADD_RELU": max(x+y, 0)
-    * "MUL": x * y
-    * "MIN": min(x, y)
-    * "MAX": max(x, y)
-    * "SUB": x - y
-    * "TRUE_DIV": x / y
-    * "FUSE_ADD_SIGMOID": sigmoid(x + y)
-    * "FUSE_ADD_TANH": tanh(x + y)
-    * "RELU": x > 0 ? x : 0
-    * "ABS": x > 0 ? x : -x
-    * "SIGMOID": sigmoid(x)
-    * "EXP": exp(x)
-    * "TANH": tanh(x)
-    * "FUSE_MUL_ADD3": x * y + z
-    * "FAST_TANH": x * (27. + x * x) / (27. + 9. * x * x)
-    * "NEGATE": -x
-    * "ACOS": acos(x)
-    * "ASIN": asin(x)
-    * "CEIL": ceil(x)
-    * "COS": cos(x)
-    * "EXPM1": expm1(x)
-    * "FLOOR": floor(x)
-    * "LOG": log(x)
-    * "LOG1P": log1p(x)
-    * "SIN": sin(x)
-    * "ROUND": round(x)
-    * "ERF": erf(x)
-    * "ERFINV": erfinv(x)
-    * "ERFC": erfc(x)
-    * "ERFCINV": erfcinv(x)
-    * "ABS_GRAD": abs_grad
-    * "FLOOR_DIV": floor_div
-    * "MOD": mod
-    * "SIGMOID_GRAD": sigmoid_grad
-    * "SWITCH_GT0": switch_gt0
-    * "TANH_GRAD": tanh_grad
-    * "LT": less
-    * "LEQ": leq
-    * "EQ": equal
-    * "POW": pow
-    * "LOG_SUM_EXP": log_sum_exp
-    * "FAST_TANH_GRAD": fast_tanh_grad
-    * "ATAN2": atan2
-    * "COND_LEQ_MOV": cond_leq_mov
-    * "H_SWISH": h_swish
-    * "FUSE_ADD_H_SWISH": h_swish(x+y)
-    * "H_SWISH_GRAD": h_swish_grad
-    * "AND": bool binary: x && y
-    * "OR": bool binary: x || y
-    * "XOR": bool binary: x ^ y
-    * "NOT": bool unary: ~x
+    * "add": a + b
+    * "fuse_add_relu": max(x+y, 0)
+    * "mul": x * y
+    * "min": min(x, y)
+    * "max": max(x, y)
+    * "sub": x - y
+    * "true_div": x / y
+    * "fuse_add_sigmoid": sigmoid(x + y)
+    * "fuse_add_tanh": tanh(x + y)
+    * "relu": x > 0 ? x : 0
+    * "abs": x > 0 ? x : -x
+    * "sigmoid": sigmoid(x)
+    * "exp": exp(x)
+    * "tanh": tanh(x)
+    * "fuse_mul_add3": x * y + z
+    * "fast_tanh": x * (27. + x * x) / (27. + 9. * x * x)
+    * "negate": -x
+    * "acos": acos(x)
+    * "asin": asin(x)
+    * "ceil": ceil(x)
+    * "cos": cos(x)
+    * "expm1": expm1(x)
+    * "floor": floor(x)
+    * "log": log(x)
+    * "log1p": log1p(x)
+    * "sin": sin(x)
+    * "round": round(x)
+    * "erf": erf(x)
+    * "erfinv": erfinv(x)
+    * "erfc": erfc(x)
+    * "erfcinv": erfcinv(x)
+    * "abs_grad": abs_grad
+    * "floor_div": floor_div
+    * "mod": mod
+    * "sigmoid_grad": sigmoid_grad
+    * "switch_gt0": switch_gt0
+    * "tanh_grad": tanh_grad
+    * "lt": less
+    * "leq": leq
+    * "eq": equal
+    * "pow": pow
+    * "log_sum_exp": log_sum_exp
+    * "fast_tanh_grad": fast_tanh_grad
+    * "atan2": atan2
+    * "cond_leq_mov": cond_leq_mov
+    * "h_swish": h_swish
+    * "fuse_add_h_swish": h_swish(x+y)
+    * "h_swish_grad": h_swish_grad
+    * "and": bool binary: x && y
+    * "or": bool binary: x || y
+    * "xor": bool binary: x ^ y
+    * "not": bool unary: ~x
     """

     def __init__(self, method, **kwargs):
diff --git a/imperative/python/megengine/module/quantized/batch_matmul_activation.py b/imperative/python/megengine/module/quantized/batch_matmul_activation.py
index e115c1463fa3e0f95723d2271f1846037df15905..a403fff2faf5228247afe9d432d4c04d5bcacd66 100644
--- a/imperative/python/megengine/module/quantized/batch_matmul_activation.py
+++ b/imperative/python/megengine/module/quantized/batch_matmul_activation.py
@@ -27,7 +27,7 @@ class BatchMatMulActivation(Float.BatchMatMulActivation, QuantizedModule):
         in_features: int,
         out_features: int,
         bias: bool = True,
-        nonlinear_mode="IDENTITY",
+        nonlinear_mode="identity",
         dtype=None,
         **kwargs
     ):
diff --git a/imperative/python/megengine/module/quantized/conv.py b/imperative/python/megengine/module/quantized/conv.py
index 0b2ad2fa88ae67b6640ba4e163d4e1b1e94eefc4..0230dda5d0bda6979ca0c84b65a2e62d3cd8f498 100644
--- a/imperative/python/megengine/module/quantized/conv.py
+++ b/imperative/python/megengine/module/quantized/conv.py
@@ -34,8 +34,8 @@ class Conv2d(Float.Conv2d, QuantizedModule):
         padding: Union[int, Tuple[int, int]] = 0,
         dilation: Union[int, Tuple[int, int]] = 1,
         groups: int = 1,
-        conv_mode: str = "CROSS_CORRELATION",
-        compute_mode: str = "DEFAULT",
+        conv_mode: str = "cross_correlation",
+        compute_mode: str = "default",
         dtype=None,
         **kwargs
     ):
@@ -53,7 +53,7 @@ class Conv2d(Float.Conv2d, QuantizedModule):
         )
         self.output_dtype = dtype

-    def calc_conv_quantized(self, inp, nonlinear_mode="IDENTITY"):
+    def calc_conv_quantized(self, inp, nonlinear_mode="identity"):
         inp_scale = dtype.get_scale(inp.dtype)
         w_scale = dtype.get_scale(self.weight.dtype)
         bias_scale = inp_scale * w_scale
@@ -100,11 +100,11 @@ class Conv2d(Float.Conv2d, QuantizedModule):
         return qconv

     def forward(self, inp):
-        return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")
+        return self.calc_conv_quantized(inp, nonlinear_mode="identity")


 class ConvRelu2d(Conv2d):
     r"""Quantized version of :class:`~.qat.ConvRelu2d`."""

     def forward(self, inp):
-        return self.calc_conv_quantized(inp, nonlinear_mode="RELU")
+        return self.calc_conv_quantized(inp, nonlinear_mode="relu")
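Since the `Elemwise` docstring above now lists lowercase method names, here is a short sketch of what the renamed strings look like at a call site (illustrative only, assuming `megengine.module.Elemwise` forwards its inputs to the named elemwise op as the tests later in this diff do):

    import megengine as mge
    import megengine.module as M

    add = M.Elemwise("add")              # previously "ADD"
    fused = M.Elemwise("fuse_add_relu")  # previously "FUSE_ADD_RELU"
    x = mge.tensor([1.0, -2.0])
    y = mge.tensor([0.5, 0.5])
    print(add(x, y).numpy(), fused(x, y).numpy())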
nonlinear_mode="IDENTITY") + return self.calc_conv_quantized(inp, nonlinear_mode="identity") class ConvBnRelu2d(_ConvBnActivation2d): r"""Quantized version of :class:`~.qat.ConvBnRelu2d`.""" def forward(self, inp): - return self.calc_conv_quantized(inp, nonlinear_mode="RELU") + return self.calc_conv_quantized(inp, nonlinear_mode="relu") diff --git a/imperative/python/megengine/module/quantized/elemwise.py b/imperative/python/megengine/module/quantized/elemwise.py index 46950c8f33a2d9f5ff67b451e7cc6774771b7aab..95ff3146b201b91dc63531b8a916420a4f55f479 100644 --- a/imperative/python/megengine/module/quantized/elemwise.py +++ b/imperative/python/megengine/module/quantized/elemwise.py @@ -16,7 +16,7 @@ class Elemwise(QuantizedModule): def __init__(self, method, dtype=None, **kwargs): super().__init__(**kwargs) - self.method = "Q" + method + self.method = "q" + method self.output_dtype = dtype def forward(self, *inps): diff --git a/imperative/python/test/run.sh b/imperative/python/test/run.sh index 6acc80d06c92e0d987d3e6358a04bec5bf430645..f70c121f97f79df8569697fe31db3872385c993b 100755 --- a/imperative/python/test/run.sh +++ b/imperative/python/test/run.sh @@ -16,9 +16,9 @@ fi export MEGENGINE_LOGGING_LEVEL="ERROR" pushd $(dirname "${BASH_SOURCE[0]}")/.. >/dev/null - PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest $test_dirs -m 'not isolated_distributed' + PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'not isolated_distributed' if [[ "$TEST_PLAT" == cuda ]]; then echo "test GPU pytest now" - PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest $test_dirs -m 'isolated_distributed' + PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'isolated_distributed' fi popd >/dev/null diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py index b95359b51b7206eb59ace98644c8a4017a4e4399..b5b4bea4d7853dd176a4933200122ee5a48b25e7 100644 --- a/imperative/python/test/unit/core/test_autodiff.py +++ b/imperative/python/test/unit/core/test_autodiff.py @@ -372,7 +372,7 @@ def test_interpolate_fastpath(): x = mge.Tensor(x_np) grad = Grad().wrt(x, callback=save_to(x)) - y = F.vision.interpolate(x, size=(16, 16), mode="BILINEAR") + y = F.vision.interpolate(x, size=(16, 16), mode="bilinear") grad(y, F.ones_like(y)) np.testing.assert_equal(np.ones(x_np.shape, dtype=np.float32) / 4, x.grad.numpy()) diff --git a/imperative/python/test/unit/functional/test_elemwise.py b/imperative/python/test/unit/functional/test_elemwise.py index 7adfdfc97133a8602b3d26fa25373f161fe9ce43..3c6d43a595a9a08c06ad6bf9ffa920175d191270 100644 --- a/imperative/python/test/unit/functional/test_elemwise.py +++ b/imperative/python/test/unit/functional/test_elemwise.py @@ -162,7 +162,7 @@ def test_qadd(): x = tensor(x, dtype=dtype.qint8(inp_scale)) y = tensor(y, dtype=dtype.qint8(inp_scale)) result_mge = F.elemwise._elemwise_multi_type( - x, y, mode="QADD", dtype=dtype.qint8(outp_scale) + x, y, mode="qadd", dtype=dtype.qint8(outp_scale) ) result_mge = result_mge.astype("float32").numpy() result_expect = x.astype("float32").numpy() + y.astype("float32").numpy() diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 067dbbc75a5e5dae64e98931347a05ac4f51bda5..684599e02f5b3231816fac51392df2a05d4e82b7 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -140,8 
diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py
index 067dbbc75a5e5dae64e98931347a05ac4f51bda5..684599e02f5b3231816fac51392df2a05d4e82b7 100644
--- a/imperative/python/test/unit/functional/test_functional.py
+++ b/imperative/python/test/unit/functional/test_functional.py
@@ -140,8 +140,8 @@ def test_interpolate():
     def linear_interpolate():
         inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))

-        out = F.vision.interpolate(inp, scale_factor=2.0, mode="LINEAR")
-        out2 = F.vision.interpolate(inp, 4, mode="LINEAR")
+        out = F.vision.interpolate(inp, scale_factor=2.0, mode="linear")
+        out2 = F.vision.interpolate(inp, 4, mode="linear")

         np.testing.assert_allclose(
             out.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32)
@@ -170,13 +170,13 @@ def test_interpolate():
         inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))

         with pytest.raises(ValueError):
-            F.vision.interpolate(inp, scale_factor=2.0, mode="LINEAR")
+            F.vision.interpolate(inp, scale_factor=2.0, mode="linear")

     def inappropriate_scale_linear_interpolate():
         inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))

         with pytest.raises(ValueError):
-            F.vision.interpolate(inp, scale_factor=[2.0, 3.0], mode="LINEAR")
+            F.vision.interpolate(inp, scale_factor=[2.0, 3.0], mode="linear")

     linear_interpolate()
     many_batch_interpolate()
@@ -339,18 +339,18 @@ def test_interpolate_fastpath():
     ]
     for inp_shape, target_shape in test_cases:
         x = tensor(np.random.randn(*inp_shape), dtype=np.float32)
-        out = F.vision.interpolate(x, target_shape, mode="BILINEAR")
+        out = F.vision.interpolate(x, target_shape, mode="bilinear")
         assert out.shape[0] == x.shape[0] and out.shape[1] == x.shape[1]
         assert out.shape[2] == target_shape[0] and out.shape[3] == target_shape[1]

     # check value
     x = tensor(np.ones((3, 3, 10, 10)), dtype=np.float32)
-    out = F.vision.interpolate(x, (15, 5), mode="BILINEAR")
+    out = F.vision.interpolate(x, (15, 5), mode="bilinear")
     np.testing.assert_equal(out.numpy(), np.ones((3, 3, 15, 5)).astype(np.float32))

     np_x = np.arange(32)
     x = tensor(np_x).astype(np.float32).reshape(1, 1, 32, 1)
-    out = F.vision.interpolate(x, (1, 1), mode="BILINEAR")
+    out = F.vision.interpolate(x, (1, 1), mode="bilinear")
     np.testing.assert_equal(out.item(), np_x.mean())
@@ -374,7 +374,7 @@ def test_warp_affine():
     inp_shape = (1, 3, 3, 3)
     x = tensor(np.arange(27, dtype=np.float32).reshape(inp_shape))
     weightv = [[[1.26666667, 0.6, -83.33333333], [-0.33333333, 1, 66.66666667]]]
-    outp = F.vision.warp_affine(x, tensor(weightv), (2, 2), border_mode="WRAP")
+    outp = F.vision.warp_affine(x, tensor(weightv), (2, 2), border_mode="wrap")
     res = np.array(
         [
             [
@@ -509,7 +509,7 @@ def test_conv_bias():
         SH,
         SW,
         has_bias=True,
-        nonlinear_mode="IDENTITY",
+        nonlinear_mode="identity",
     ):
         inp_v = np.random.normal(size=(N, IC, IH, IW))
         w_v = np.random.normal(size=(OC, IC, KH, KW))
@@ -541,7 +541,7 @@ def test_conv_bias():
         O = F.conv2d(
             inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
         )
-        if nonlinear_mode == "RELU":
+        if nonlinear_mode == "relu":
             return F.relu(O)
         else:
             return O
@@ -583,8 +583,8 @@ def test_conv_bias():
     run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
     run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)

-    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "RELU")
-    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")
+    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "relu")
+    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")


 @pytest.mark.skipif(
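The same renaming applies to `border_mode` in the warp ops; a sketch of the lowercase spelling, modelled on the `test_warp_affine` call above (illustrative only):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    x = tensor(np.arange(27, dtype=np.float32).reshape(1, 3, 3, 3))
    mat = tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])  # identity affine matrix
    out = F.vision.warp_affine(x, mat, (2, 2), border_mode="wrap")  # was "WRAP"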
diff --git a/imperative/python/test/unit/module/test_elemwise.py b/imperative/python/test/unit/module/test_elemwise.py
index 4e400797531d9785ca31eac11dfb8844498fb1f4..009de1f042f754deb83c967a5aa5855ca767696a 100644
--- a/imperative/python/test/unit/module/test_elemwise.py
+++ b/imperative/python/test/unit/module/test_elemwise.py
@@ -23,8 +23,8 @@ def test_module_elemwise():
     y = np.random.rand(100).astype("float32")
     x, y = tensor(x), tensor(y)
     np.testing.assert_almost_equal(
-        test_func("H_SWISH", x), F.hswish(x).numpy(), decimal=6
+        test_func("h_swish", x), F.hswish(x).numpy(), decimal=6
     )
     np.testing.assert_almost_equal(
-        test_func("ADD", x, y), F.add(x, y).numpy(), decimal=6
+        test_func("add", x, y), F.add(x, y).numpy(), decimal=6
     )
diff --git a/imperative/python/test/unit/quantization/test_module.py b/imperative/python/test/unit/quantization/test_module.py
index bbb95cff6eb92a69e4eb21ef290865c86bea3ab1..542afea77914f88f243bcaaa862e0accdeab6ac9 100644
--- a/imperative/python/test/unit/quantization/test_module.py
+++ b/imperative/python/test/unit/quantization/test_module.py
@@ -133,7 +133,7 @@ def test_dequant_stub():
     np.testing.assert_allclose(q, fake_quant_normal.numpy())


-@pytest.mark.parametrize("kind", ["COS", "RELU", "ADD", "MUL", "FUSE_ADD_RELU"])
+@pytest.mark.parametrize("kind", ["cos", "relu", "add", "mul", "fuse_add_relu"])
 def test_elemwise(kind):
     normal_net = Float.Elemwise(kind)
     normal_net.eval()
@@ -167,7 +167,7 @@ def test_elemwise(kind):
     x2_int8 = quant(x2, x2_scale)

     # test correctness of `Float`, `QAT` and `Quantized`
-    if kind in ("ADD", "MUL", "FUSE_ADD_RELU"):
+    if kind in ("add", "mul", "fuse_add_relu"):
         normal = normal_net(x1, x2)
         qat_without_fakequant = qat_from_float(x1, x2)
         fake_quant_normal = fake_quant_act(normal_net(x1, x2), act_scale)
diff --git a/imperative/python/test/unit/quantization/test_op.py b/imperative/python/test/unit/quantization/test_op.py
index 2095565622794ef07af84649fd9711ddb4ac4155..53500751ab574c22da062e373e21cf831a670d3a 100644
--- a/imperative/python/test/unit/quantization/test_op.py
+++ b/imperative/python/test/unit/quantization/test_op.py
@@ -22,7 +22,7 @@ def fake_quant(x, scale):
     return x


-@pytest.mark.parametrize("kind", ["ABS", "SIN", "SUB", "MUL", "FUSE_ADD_TANH"])
+@pytest.mark.parametrize("kind", ["abs", "sin", "sub", "mul", "fuse_add_tanh"])
 def test_elemwise(kind):
     x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
     x1_scale = np.float32(np.random.rand() + 1)
@@ -39,8 +39,8 @@ def test_elemwise(kind):
     output_scale = np.float32(np.random.rand() + 1)
     output_dtype = dtype.qint8(output_scale)

-    quantized_kind = "Q" + kind
-    if kind in ("ABS", "SIN"):
+    quantized_kind = "q" + kind
+    if kind in ("abs", "sin"):
         desired_out = fake_quant(_elwise(x1, mode=kind), output_scale)
         actual_out = (
             _elemwise_multi_type(
@@ -84,7 +84,7 @@ def test_conv_bias():
         SH,
         SW,
         has_bias=True,
-        nonlinear_mode="IDENTITY",
+        nonlinear_mode="identity",
     ):
         inp_v = np.random.normal(size=(N, IC, IH, IW))
         w_v = np.random.normal(size=(OC, IC, KH, KW))
@@ -116,7 +116,7 @@ def test_conv_bias():
         O = F.conv2d(
             inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
         )
-        if nonlinear_mode == "RELU":
+        if nonlinear_mode == "relu":
             return F.relu(O)
         else:
             return O
@@ -158,5 +158,5 @@ def test_conv_bias():
     run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
     run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)

-    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "RELU")
-    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")
+    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "relu")
+    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")
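The quantization tests above lean on the lowercase prefixing rule (`"q" + kind`), so the user-visible quantized mode strings follow suit. A sketch based on the `test_qadd` call pattern (illustrative only; it assumes the internal `_elemwise_multi_type` helper and the `megengine.core.tensor.dtype` helpers keep the signatures used in these tests):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor
    from megengine.core.tensor import dtype

    x = tensor(np.random.normal(size=(3, 3)).astype("float32"), dtype=dtype.qint8(1.0))
    y = tensor(np.random.normal(size=(3, 3)).astype("float32"), dtype=dtype.qint8(1.0))
    out = F.elemwise._elemwise_multi_type(x, y, mode="qadd", dtype=dtype.qint8(1.0))  # was "QADD"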
b/imperative/python/test/unit/utils/test_network_node.py
@@ -280,7 +280,7 @@ def test_convbias():
     @trace(symbolic=True, capture_as_const=True)
     def fwd(inp, weight, bias):
         return F.quantized.conv_bias_activation(
-            inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="RELU"
+            inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="relu"
         )

     inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0))
@@ -297,7 +297,7 @@ def test_batch_convbias():
     @trace(symbolic=True, capture_as_const=True)
     def fwd(inp, weight, bias):
         return F.quantized.batch_conv_bias_activation(
-            inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="RELU"
+            inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="relu"
         )

     inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0))
@@ -358,7 +358,7 @@ def test_warpaffine():

     @trace(symbolic=True, capture_as_const=True)
     def fwd(x, weightv):
-        return F.vision.warp_affine(x, weightv, (2, 2), border_mode="WRAP")
+        return F.vision.warp_affine(x, weightv, (2, 2), border_mode="wrap")

     outp = fwd(x, weightv)
     check_pygraph_dump(fwd, [x, weightv], [outp])
@@ -387,7 +387,7 @@ def test_resize():

     @trace(symbolic=True, capture_as_const=True)
     def fwd(x):
-        return F.vision.interpolate(x, size=(16, 16), mode="BILINEAR")
+        return F.vision.interpolate(x, size=(16, 16), mode="bilinear")

     out = fwd(x)
     check_pygraph_dump(fwd, [x], [out])
@@ -697,7 +697,7 @@ def test_assert_equal():


 def test_elemwise_multitype():
-    op = builtin.ElemwiseMultiType(mode="QADD", dtype=dtype.qint32(2.0))
+    op = builtin.ElemwiseMultiType(mode="qadd", dtype=dtype.qint32(2.0))

     @trace(symbolic=True, capture_as_const=True)
     def fwd(x, y):