提交 65af9cfc 编写于 作者: M Megvii Engine Team

refactor(mge): use lower case for default string parameters in functional and module

GitOrigin-RevId: dbc1f27ff722c8006962a0440f824cec02057ab6
上级 d0aa9b41
......@@ -65,7 +65,7 @@ def _elwise(*args, mode):
def _matmul(inp1, inp2):
op = builtin.MatrixMul(
transposeA=False, transposeB=False, compute_mode="DEFAULT", format="DEFAULT"
transposeA=False, transposeB=False, compute_mode="default", format="default"
)
inp1, inp2 = utils.convert_inputs(inp1, inp2)
(result,) = apply(op, inp1, inp2)
......@@ -178,7 +178,7 @@ def _reduce(mode):
def f(self, axis=None, keepdims: bool = False):
data = self
(data,) = utils.convert_inputs(data)
if mode == "MEAN":
if mode == "mean":
data = data.astype("float32")
elif self.dtype == np.bool_:
data = data.astype("int32")
......@@ -204,7 +204,7 @@ def _reduce(mode):
if not keepdims:
result = _remove_axis(result, axis)
if self.dtype == np.bool_:
if mode in ["MIN", "MAX"]:
if mode in ["min", "max"]:
result = result.astype("bool")
if axis is None or self.ndim == 1:
setscalar(result)
......@@ -479,7 +479,7 @@ class ArrayMethodMixin(abc.ABC):
10.0
"""
return _reduce("SUM")(self, axis, keepdims)
return _reduce("sum")(self, axis, keepdims)
def prod(self, axis=None, keepdims: bool = False):
r"""
......@@ -512,7 +512,7 @@ class ArrayMethodMixin(abc.ABC):
24.0
"""
return _reduce("PRODUCT")(self, axis, keepdims)
return _reduce("product")(self, axis, keepdims)
def min(self, axis=None, keepdims: bool = False):
r"""
......@@ -545,7 +545,7 @@ class ArrayMethodMixin(abc.ABC):
1.0
"""
return _reduce("MIN")(self, axis, keepdims)
return _reduce("min")(self, axis, keepdims)
def max(self, axis=None, keepdims: bool = False):
r"""
......@@ -578,7 +578,7 @@ class ArrayMethodMixin(abc.ABC):
4.0
"""
return _reduce("MAX")(self, axis, keepdims)
return _reduce("max")(self, axis, keepdims)
def mean(self, axis=None, keepdims: bool = False):
r"""
......@@ -611,4 +611,4 @@ class ArrayMethodMixin(abc.ABC):
2.5
"""
return _reduce("MEAN")(self, axis, keepdims)
return _reduce("mean")(self, axis, keepdims)
......@@ -267,6 +267,7 @@ def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
1.5
"""
norm = norm.upper()
assert norm in ["L1", "L2"], "norm must be L1 or L2"
# Converts binary labels to -1/1 labels.
loss = relu(1.0 - pred * label)
......
......@@ -604,9 +604,9 @@ def argsort(inp: Tensor, descending: bool = False) -> Tensor:
"""
assert len(inp.shape) <= 2, "Input should be 1d or 2d"
if descending:
order = "DESCENDING"
order = "descending"
else:
order = "ASCENDING"
order = "ascending"
op = builtin.Argsort(order=order)
if len(inp.shape) == 1:
......@@ -646,9 +646,9 @@ def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
"""
assert len(inp.shape) <= 2, "Input should be 1d or 2d"
if descending:
order = "DESCENDING"
order = "descending"
else:
order = "ASCENDING"
order = "ascending"
op = builtin.Argsort(order=order)
if len(inp.shape) == 1:
......@@ -699,11 +699,11 @@ def topk(
inp = -inp
if kth_only:
mode = "KTH_ONLY"
mode = "kth_only"
elif no_sort:
mode = "VALUE_IDX_NOSORT"
mode = "value_idx_nosort"
else:
mode = "VALUE_IDX_SORTED"
mode = "value_idx_sorted"
op = builtin.TopK(mode=mode)
if not isinstance(k, Tensor):
......@@ -765,8 +765,8 @@ def matmul(
inp2: Tensor,
transpose_a=False,
transpose_b=False,
compute_mode="DEFAULT",
format="DEFAULT",
compute_mode="default",
format="default",
) -> Tensor:
"""
Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.
......@@ -776,7 +776,9 @@ def matmul(
- Both 1-D tensor, simply forward to ``dot``.
- Both 2-D tensor, normal matrix multiplication.
- If one input tensor is 1-D, matrix vector multiplication.
- If at least one tensor are 3-dimensional or >3-dimensional, the other tensor should have dim >= 2, the batched matrix-matrix is returned, and the tensor with smaller dimension will be broadcasted. For example:
- If at least one tensor are 3-dimensional or >3-dimensional, the other tensor should have dim >= 2,
the batched matrix-matrix is returned, and the tensor with smaller dimension will be broadcasted.
For example:
- inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)`
- inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)`
......
......@@ -52,6 +52,8 @@ __all__ = [
"deformable_psroi_pooling",
"dropout",
"embedding",
"hsigmoid",
"hswish",
"indexing_one_hot",
"leaky_relu",
"linear",
......@@ -62,17 +64,14 @@ __all__ = [
"max_pool2d",
"one_hot",
"prelu",
"softmax",
"softplus",
"sync_batch_norm",
"conv1d",
"sigmoid",
"hsigmoid",
"relu",
"relu6",
"hswish",
"resize",
"remap",
"resize",
"sigmoid",
"softmax",
"softplus",
"sync_batch_norm",
"warp_affine",
"warp_perspective",
]
......@@ -106,6 +105,83 @@ def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor
return ret
def conv1d(
inp: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
conv_mode="cross_correlation",
compute_mode="default",
) -> Tensor:
"""1D convolution operation.
Refer to :class:`~.Conv1d` for more information.
:param inp: The feature map of the convolution operation
:param weight: The convolution kernel
:param bias: The bias added to the result of convolution (if given)
:param stride: Stride of the 1D convolution operation. Default: 1
:param padding: Size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: Dilation of the 1D convolution operation. Default: 1
:param groups: number of groups to divide input and output channels into,
so as to perform a "grouped convolution". When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be ``(groups, out_channel // groups,
in_channels // groups, height, width)``.
:type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode`
:param conv_mode: Supports 'cross_correlation'. Default:
'cross_correlation'.
:type compute_mode: string or
:class:`mgb.opr_param_defs.Convolution.ComputeMode`
:param compute_mode: When set to 'default', no special requirements will be
placed on the precision of intermediate results. When set to 'float32',
float32 would be used for accumulator and intermediate result, but only
effective when input and output are of float16 dtype.
"""
assert (
conv_mode.lower() == "cross_correlation"
or conv_mode.name == "CROSS_CORRELATION"
)
assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
assert inp.ndim == 3, "the input dimension of conv1d should be 3"
assert weight.ndim == 3, "the weight dimension of conv1d should be 3"
inp = expand_dims(inp, 3)
weight = expand_dims(weight, 3)
if bias is not None:
assert bias.ndim == 3, "the bias dimension of conv1d should be 3"
bias = expand_dims(bias, 3)
stride_h = stride
pad_h = padding
dilate_h = dilation
sparse_type = "dense" if groups == 1 else "group"
op = builtin.Convolution(
stride_h=stride_h,
stride_w=1,
pad_h=pad_h,
pad_w=0,
dilate_h=dilate_h,
dilate_w=1,
strategy=get_execution_strategy(),
mode=conv_mode,
compute_mode=compute_mode,
sparse=sparse_type,
)
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight)
if bias is not None:
output += bias
output = squeeze(output, 3)
return output
def conv2d(
inp: Tensor,
weight: Tensor,
......@@ -114,8 +190,8 @@ def conv2d(
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
conv_mode="CROSS_CORRELATION",
compute_mode="DEFAULT",
conv_mode="cross_correlation",
compute_mode="default",
) -> Tensor:
"""
2D convolution operation.
......@@ -135,24 +211,27 @@ def conv2d(
and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`.
:type conv_mode: string or :class:`Convolution.Mode`
:param conv_mode: supports "CROSS_CORRELATION". Default:
"CROSS_CORRELATION"
:param conv_mode: supports "cross_correlation". Default:
"cross_correlation"
:type compute_mode: string or
:class:`Convolution.ComputeMode`
:param compute_mode: when set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32",
"Float32" would be used for accumulator and intermediate result, but only
effective when input and output are of Float16 dtype.
:param compute_mode: when set to "default", no special requirements will be
placed on the precision of intermediate results. When set to "float32",
"float32" would be used for accumulator and intermediate result, but only
effective when input and output are of float16 dtype.
:return: output tensor.
"""
assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION"
assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT"
assert (
conv_mode.lower() == "cross_correlation"
or conv_mode.name == "CROSS_CORRELATION"
)
assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding)
dilate_h, dilate_w = expand_hw(dilation)
sparse_type = "DENSE" if groups == 1 else "GROUP"
sparse_type = "dense" if groups == 1 else "group"
op = builtin.Convolution(
stride_h=stride_h,
stride_w=stride_w,
......@@ -180,7 +259,7 @@ def conv3d(
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
groups: int = 1,
conv_mode: str = "CROSS_CORRELATION",
conv_mode: str = "cross_correlation",
) -> Tensor:
"""
3D convolution operation.
......@@ -194,15 +273,16 @@ def conv3d(
:param padding: size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 3D convolution operation. Default: 1
:param groups: number of groups into which the input and output channels are divided, so as to perform a ``grouped convolution``. When ``groups`` is not 1,
:param groups: number of groups into which the input and output channels are divided,
so as to perform a ``grouped convolution``. When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, t, height, width)`.
:param conv_mode: supports "CROSS_CORRELATION". Default:
"CROSS_CORRELATION"
:param conv_mode: supports "cross_correlation". Default:
"cross_correlation"
:return: output tensor.
"""
assert conv_mode == "CROSS_CORRELATION"
assert conv_mode.lower() == "cross_correlation"
D, H, W = 0, 1, 2
......@@ -210,7 +290,7 @@ def conv3d(
stride = _triple_nonzero(stride)
dilate = _triple_nonzero(dilation)
sparse_type = "DENSE" if groups == 1 else "GROUP"
sparse_type = "dense" if groups == 1 else "group"
op = builtin.Convolution3D(
pad_d=pad[D],
pad_h=pad[H],
......@@ -240,8 +320,8 @@ def conv_transpose2d(
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
conv_mode="CROSS_CORRELATION",
compute_mode="DEFAULT",
conv_mode="cross_correlation",
compute_mode="default",
) -> Tensor:
"""
2D transposed convolution operation.
......@@ -261,18 +341,21 @@ def conv_transpose2d(
and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`. Default: 1
:type conv_mode: string or :class:`Convolution.Mode`
:param conv_mode: supports "CROSS_CORRELATION". Default:
"CROSS_CORRELATION"
:param conv_mode: supports "cross_correlation". Default:
"cross_correlation"
:type compute_mode: string or
:class:`Convolution.ComputeMode`
:param compute_mode: when set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32",
"Float32" would be used for accumulator and intermediate result, but only
effective when input and output are of Float16 dtype.
:param compute_mode: when set to "default", no special requirements will be
placed on the precision of intermediate results. When set to "float32",
"float32" would be used for accumulator and intermediate result, but only
effective when input and output are of float16 dtype.
:return: output tensor.
"""
assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION"
assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT"
assert (
conv_mode.lower() == "cross_correlation"
or conv_mode.name == "CROSS_CORRELATION"
)
assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
if groups != 1:
raise NotImplementedError("group transposed conv2d is not supported yet.")
......@@ -307,8 +390,8 @@ def deformable_conv2d(
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
conv_mode="CROSS_CORRELATION",
compute_mode="DEFAULT",
conv_mode="cross_correlation",
compute_mode="default",
) -> Tensor:
"""
Deformable Convolution.
......@@ -328,24 +411,27 @@ def deformable_conv2d(
and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`. Default: 1
:type conv_mode: string or :class:`Convolution.Mode`
:param conv_mode: supports "CROSS_CORRELATION". Default:
"CROSS_CORRELATION"
:param conv_mode: supports "cross_correlation". Default:
"cross_correlation"
:type compute_mode: string or
:class:`Convolution.ComputeMode`
:param compute_mode: when set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32",
"Float32" would be used for accumulator and intermediate result, but only
effective when input and output are of Float16 dtype.
:param compute_mode: when set to "default", no special requirements will be
placed on the precision of intermediate results. When set to "float32",
"float32" would be used for accumulator and intermediate result, but only
effective when input and output are of float16 dtype.
:return: output tensor.
"""
assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION"
assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT"
assert (
conv_mode.lower() == "cross_correlation"
or conv_mode.name == "CROSS_CORRELATION"
)
assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding)
dilate_h, dilate_w = expand_hw(dilation)
sparse_type = "DENSE" if groups == 1 else "GROUP"
sparse_type = "dense" if groups == 1 else "group"
op = builtin.DeformableConv(
stride_h=stride_h,
stride_w=stride_w,
......@@ -372,10 +458,13 @@ def local_conv2d(
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
conv_mode="CROSS_CORRELATION",
conv_mode="cross_correlation",
):
"""Applies spatial 2D convolution over an groupped channeled image with untied kernels."""
assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION"
assert (
conv_mode.lower() == "cross_correlation"
or conv_mode.name == "CROSS_CORRELATION"
)
stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding)
......@@ -389,8 +478,8 @@ def local_conv2d(
dilate_h=dilate_h,
dilate_w=dilate_w,
mode=conv_mode,
compute_mode="DEFAULT",
sparse="DENSE",
compute_mode="default",
sparse="dense",
)
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight)
......@@ -430,7 +519,7 @@ def max_pool2d(
stride_w=stride_w,
pad_h=padding_h,
pad_w=padding_w,
mode="MAX",
mode="max",
)
(output,) = apply(op, inp)
return output
......@@ -441,7 +530,7 @@ def avg_pool2d(
kernel_size: Union[int, Tuple[int, int]],
stride: Optional[Union[int, Tuple[int, int]]] = None,
padding: Union[int, Tuple[int, int]] = 0,
mode: str = "AVERAGE_COUNT_EXCLUDE_PADDING",
mode: str = "average_count_exclude_padding",
) -> Tensor:
"""
Applies 2D average pooling over an input tensor.
......@@ -453,7 +542,8 @@ def avg_pool2d(
:param stride: stride of the window. If not provided, its value is set to ``kernel_size``.
Default: None
:param padding: implicit zero padding added on both sides. Default: 0
:param mode: whether to count padding values. Default: "AVERAGE_COUNT_EXCLUDE_PADDING"
:param mode: whether to count padding values, set to "average" will do counting.
Default: "average_count_exclude_padding"
:return: output tensor.
"""
if stride is None:
......@@ -490,7 +580,7 @@ def adaptive_max_pool2d(
if isinstance(oshp, int):
oshp = (oshp, oshp)
op = builtin.AdaptivePooling(mode="MAX", format="NCHW",)
op = builtin.AdaptivePooling(mode="max", format="NCHW",)
oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
(output,) = apply(op, inp, oshp)
return output
......@@ -511,7 +601,7 @@ def adaptive_avg_pool2d(
if isinstance(oshp, int):
oshp = (oshp, oshp)
op = builtin.AdaptivePooling(mode="AVERAGE", format="NCHW",)
op = builtin.AdaptivePooling(mode="average", format="NCHW",)
oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
(output,) = apply(op, inp, oshp)
return output
......@@ -556,6 +646,53 @@ def deformable_psroi_pooling(
return output
def hswish(x):
"""
Element-wise `x * relu6(x + 3) / 6`.
:param x: input tensor.
:return: computed tensor.
Example:
.. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
x = tensor(np.arange(5).astype(np.float32))
out = F.hswish(x)
print(out.numpy().round(decimals=4))
.. testoutput::
[0. 0.6667 1.6667 3. 4. ]
"""
return _elwise(x, mode=Elemwise.Mode.H_SWISH)
def sigmoid(x):
"""Element-wise `1 / ( 1 + exp( -x ) )`."""
return _elwise(x, mode=Elemwise.Mode.SIGMOID)
def hsigmoid(x):
"""Element-wise `relu6(x + 3) / 6`."""
return relu6(x + 3) / 6
def relu(x):
"""Element-wise `max(x, 0)`."""
return _elwise(x, mode=Elemwise.Mode.RELU)
def relu6(x):
"""Element-wise `min(max(x, 0), 6)`."""
return minimum(maximum(x, 0), 6)
def prelu(inp: Tensor, weight: Tensor) -> Tensor:
r"""
Applies the element-wise PReLU function.
......@@ -872,14 +1009,14 @@ def batch_norm(
if not training:
op = builtin.BatchNorm(
fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="DIM_1C11"
fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="dim_1c11"
)
ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
return ret
else:
op = builtin.BatchNorm(
avg_factor=1 - momentum, epsilon=eps, param_dim="DIM_1C11"
avg_factor=1 - momentum, epsilon=eps, param_dim="dim_1c11"
)
if has_mean or has_var:
running_mean = make_full_if_none(running_mean, 0)
......@@ -915,7 +1052,7 @@ def sync_batch_norm(
training: bool = False,
momentum: Union[float, Tensor] = 0.9,
eps: float = 1e-5,
eps_mode="ADDITIVE",
eps_mode="additive",
group=WORLD,
) -> Tensor:
r"""
......@@ -939,7 +1076,9 @@ def sync_batch_norm(
Default: 1e-5
:return: output tensor.
"""
assert eps_mode in {"MAX", "ADDITIVE"}, "unknown eps_mode: {}".format(eps_mode)
assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format(
eps_mode
)
_channels = inp.shape[1]
_ndim = inp.ndim
_device = inp.device
......@@ -979,7 +1118,7 @@ def sync_batch_norm(
channel_mean = running_mean.reshape(*_param_shape)
invsqrt_channel_variance = (
maximum(channel_variance, eps) if eps_mode == "MAX" else channel_variance + eps
maximum(channel_variance, eps) if eps_mode == "max" else channel_variance + eps
) ** -0.5
if weight is not None:
......@@ -1019,13 +1158,16 @@ def sync_batch_norm(
return outvar
def one_hot(inp: Tensor, num_classes: int) -> Tensor:
r"""
Performs one-hot encoding for the input tensor.
def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
"""
Returns a new tensor where each of the elements are randomly set to zero
with probability P = ``drop_prob``. Optionally rescale the output tensor if ``training`` is True.
:param inp: input tensor.
:param num_classes: number of classes denotes the last dimension of the output tensor.
:return: output tensor.
:param drop_prob: probability to drop (set to zero) a single element.
:param training: the default behavior of ``dropout`` during training is to rescale the output,
then it can be replaced by an :class:`~.Identity` during inference. Default: True
:return: the output tensor
Examples:
......@@ -1035,51 +1177,33 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor:
from megengine import tensor
import megengine.functional as F
x = tensor(np.arange(1, 4, dtype=np.int32))
out = F.one_hot(x, num_classes=4)
x = tensor(np.ones(10, dtype=np.float32))
out = F.dropout(x, 1./3.)
print(out.numpy())
Outputs:
.. testoutput::
:options: +SKIP
[[0 1 0 0]
[0 0 1 0]
[0 0 0 1]]
"""
zeros_tensor = zeros(list(inp.shape) + [num_classes], inp.dtype, inp.device)
ones_tensor = ones(list(inp.shape) + [1], inp.dtype, inp.device)
op = builtin.IndexingSetOneHot(axis=inp.ndim)
(result,) = apply(op, zeros_tensor, inp, ones_tensor)
return result
[1.5 1.5 0. 1.5 1.5 1.5 1.5 1.5 1.5 1.5]
def matmul(
inp1: Tensor,
inp2: Tensor,
transpose_a=False,
transpose_b=False,
compute_mode="DEFAULT",
format="DEFAULT",
) -> Tensor:
"""
Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.
assert 0 <= drop_prob < 1
rv = uniform(size=inp.shape)
mask = rv > drop_prob
inp *= mask.astype(inp.dtype)
if training:
inp *= 1 / (1 - drop_prob)
return inp
With different inputs dim, this function behaves differently:
- Both 1-D tensor, simply forward to ``dot``.
- Both 2-D tensor, normal matrix multiplication.
- If one input tensor is 1-D, matrix vector multiplication.
- If at least one tensor are 3-dimensional or >3-dimensional, the other tensor should have dim >= 2, the batched matrix-matrix is returned, and the tensor with smaller dimension will
be broadcasted. For example:
- inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)`
- inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)`
- inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)`
def one_hot(inp: Tensor, num_classes: int) -> Tensor:
r"""
Performs one-hot encoding for the input tensor.
:param inp1: first matrix to be multiplied.
:param inp2: second matrix to be multiplied.
:param inp: input tensor.
:param num_classes: number of classes denotes the last dimension of the output tensor.
:return: output tensor.
Examples:
......@@ -1090,182 +1214,27 @@ def matmul(
from megengine import tensor
import megengine.functional as F
data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
data2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2))
out = F.matmul(data1, data2)
x = tensor(np.arange(1, 4, dtype=np.int32))
out = F.one_hot(x, num_classes=4)
print(out.numpy())
Outputs:
.. testoutput::
[[10. 13.]
[28. 40.]]
"""
remove_row, remove_col = False, False
inp1, inp2 = utils.convert_inputs(inp1, inp2)
dim1, dim2 = inp1.ndim, inp2.ndim
# handle dim=1 cases, dot and matrix-vector multiplication
if dim1 == 1 and dim2 == 1:
return dot(inp1, inp2)
# the underlying matmul op requires input dims to be at least 2
if dim1 == 1:
inp1 = expand_dims(inp1, 0)
dim1 = 2
remove_row = True
if dim2 == 1:
inp2 = expand_dims(inp2, 1)
dim2 = 2
remove_col = True
batch_shape = None
shape1 = inp1.shape
shape2 = inp2.shape
maxdim = dim1 if dim1 > dim2 else dim2
if dim1 >= 3 or dim2 >= 3:
if use_symbolic_shape():
if dim1 > dim2:
shape2 = concat([shape1[:-2], shape2[-2:]])
inp2 = broadcast_to(inp2, shape2)
if dim1 < dim2:
shape1 = concat([shape2[:-2], shape1[-2:]])
inp1 = broadcast_to(inp1, shape1)
if maxdim > 3:
batch_shape = shape1[:-2]
# compress inputs to 3d
(inp1,) = apply(
builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]])
)
(inp2,) = apply(
builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]])
)
else:
if dim1 > dim2:
shape2 = shape1[:-2] + shape2[-2:]
inp2 = broadcast_to(inp2, shape2)
if dim1 < dim2:
shape1 = shape2[:-2] + shape1[-2:]
inp1 = broadcast_to(inp1, shape1)
if maxdim > 3:
batch_shape = shape1[:-2]
# compress inputs to 3d
inp1 = inp1.reshape((-1, shape1[-2], shape1[-1]))
inp2 = inp2.reshape((-1, shape2[-2], shape2[-1]))
op = builtin.BatchedMatrixMul(
transposeA=transpose_a,
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
strategy=get_execution_strategy(),
)
else:
op = builtin.MatrixMul(
transposeA=transpose_a,
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
strategy=get_execution_strategy(),
)
(result,) = apply(op, inp1, inp2)
if maxdim > 3:
if use_symbolic_shape():
(result,) = apply(
builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]])
)
else:
result = result.reshape(batch_shape + result.shape[-2:])
if remove_row:
result = squeeze(result, axis=-2)
if remove_col:
result = squeeze(result, axis=-1)
return result
[[0 1 0 0]
[0 0 1 0]
[0 0 0 1]]
def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
"""
Computes dot-product of two vectors ``inp1`` and ``inp2``.
inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted.
Refer to :func:`~.matmul` for more general usage.
:param inp1: first vector.
:param inp2: second vector.
:return: output value.
Examples:
.. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
data1 = tensor(np.arange(0, 6, dtype=np.float32))
data2 = tensor(np.arange(0, 6, dtype=np.float32))
out = F.dot(data1, data2)
print(out.numpy())
Outputs:
.. testoutput::
55.
zeros_tensor = zeros(list(inp.shape) + [num_classes], inp.dtype, inp.device)
ones_tensor = ones(list(inp.shape) + [1], inp.dtype, inp.device)
"""
op = builtin.Dot()
inp1, inp2 = utils.convert_inputs(inp1, inp2)
assert (
inp1.ndim <= 1 and inp2.ndim <= 1
), "Input tensors for dot must be 1-dimensional or scalar"
(result,) = apply(op, inp1, inp2)
setscalar(result)
op = builtin.IndexingSetOneHot(axis=inp.ndim)
(result,) = apply(op, zeros_tensor, inp, ones_tensor)
return result
def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
"""
Returns a new tensor where each of the elements are randomly set to zero
with probability P = ``drop_prob``. Optionally rescale the output tensor if ``training`` is True.
:param inp: input tensor.
:param drop_prob: probability to drop (set to zero) a single element.
:param training: the default behavior of ``dropout`` during training is to rescale the output,
then it can be replaced by an :class:`~.Identity` during inference. Default: True
:return: the output tensor
Examples:
.. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
x = tensor(np.ones(10, dtype=np.float32))
out = F.dropout(x, 1./3.)
print(out.numpy())
Outputs:
.. testoutput::
:options: +SKIP
[1.5 1.5 0. 1.5 1.5 1.5 1.5 1.5 1.5 1.5]
"""
assert 0 <= drop_prob < 1
rv = uniform(size=inp.shape)
mask = rv > drop_prob
inp *= mask.astype(inp.dtype)
if training:
inp *= 1 / (1 - drop_prob)
return inp
def embedding(
inp: Tensor,
weight: Tensor,
......@@ -1334,128 +1303,6 @@ def indexing_one_hot(
return result
def conv1d(
inp: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
conv_mode="CROSS_CORRELATION",
compute_mode="DEFAULT",
) -> Tensor:
"""1D convolution operation.
Refer to :class:`~.Conv1d` for more information.
:param inp: The feature map of the convolution operation
:param weight: The convolution kernel
:param bias: The bias added to the result of convolution (if given)
:param stride: Stride of the 1D convolution operation. Default: 1
:param padding: Size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: Dilation of the 1D convolution operation. Default: 1
:param groups: number of groups to divide input and output channels into,
so as to perform a "grouped convolution". When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be ``(groups, out_channel // groups,
in_channels // groups, height, width)``.
:type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode`
:param conv_mode: Supports 'CROSS_CORRELATION'. Default:
'CROSS_CORRELATION'.
:type compute_mode: string or
:class:`mgb.opr_param_defs.Convolution.ComputeMode`
:param compute_mode: When set to 'DEFAULT', no special requirements will be
placed on the precision of intermediate results. When set to 'FLOAT32',
Float32 would be used for accumulator and intermediate result, but only
effective when input and output are of Float16 dtype.
"""
assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION"
assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT"
assert inp.ndim == 3, "the input dimension of conv1d should be 3"
assert weight.ndim == 3, "the weight dimension of conv1d should be 3"
inp = expand_dims(inp, 3)
weight = expand_dims(weight, 3)
if bias is not None:
assert bias.ndim == 3, "the bias dimension of conv1d should be 3"
bias = expand_dims(bias, 3)
stride_h = stride
pad_h = padding
dilate_h = dilation
sparse_type = "DENSE" if groups == 1 else "GROUP"
op = builtin.Convolution(
stride_h=stride_h,
stride_w=1,
pad_h=pad_h,
pad_w=0,
dilate_h=dilate_h,
dilate_w=1,
strategy=get_execution_strategy(),
mode=conv_mode,
compute_mode=compute_mode,
sparse=sparse_type,
)
inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight)
if bias is not None:
output += bias
output = squeeze(output, 3)
return output
def hswish(x):
"""
Element-wise `x * relu6(x + 3) / 6`.
:param x: input tensor.
:return: computed tensor.
Example:
.. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
x = tensor(np.arange(5).astype(np.float32))
out = F.hswish(x)
print(out.numpy().round(decimals=4))
.. testoutput::
[0. 0.6667 1.6667 3. 4. ]
"""
return _elwise(x, mode=Elemwise.Mode.H_SWISH)
def sigmoid(x):
"""Element-wise `1 / ( 1 + exp( -x ) )`."""
return _elwise(x, mode=Elemwise.Mode.SIGMOID)
def hsigmoid(x):
"""Element-wise `relu6(x + 3) / 6`."""
return relu6(x + 3) / 6
def relu(x):
"""Element-wise `max(x, 0)`."""
return _elwise(x, mode=Elemwise.Mode.RELU)
def relu6(x):
"""Element-wise `min(max(x, 0), 6)`."""
return minimum(maximum(x, 0), 6)
interpolate = deprecated_func("1.3", "megengine.functional.vision", "interpolate", True)
roi_pooling = deprecated_func("1.3", "megengine.functional.vision", "roi_pooling", True)
roi_align = deprecated_func("1.3", "megengine.functional.vision", "roi_align", True)
......
......@@ -24,9 +24,9 @@ def conv_bias_activation(
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
nonlinear_mode="IDENTITY",
conv_mode="CROSS_CORRELATION",
compute_mode="DEFAULT",
nonlinear_mode="identity",
conv_mode="cross_correlation",
compute_mode="default",
) -> Tensor:
"""
Convolution bias with activation operation, only for inference.
......@@ -35,27 +35,30 @@ def conv_bias_activation(
:param weight: convolution kernel.
:param bias: bias added to the result of convolution
:param stride: stride of the 2D convolution operation. Default: 1
:param padding: size of the paddings added to the input on both sides of its spatial dimensions. Only zero-padding is supported. Default: 0
:param padding: size of the paddings added to the input on both sides
of its spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 2D convolution operation. Default: 1
:param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1,
:param groups: number of groups into which the input and output channels are divided,
so as to perform a "grouped convolution". When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`.
:type conv_mode: string or :class:`Convolution.Mode`.
:param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
'CROSS_CORRELATION'
:param conv_mode: supports 'cross_correlation' or 'convolution'. Default:
'cross_correlation'
:param dtype: support for ``np.dtype``, Default: np.int8
:type compute_mode: string or
:class:`Convolution.ComputeMode`.
:param compute_mode: when set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32",
"Float32" would be used for accumulator and intermediate result, but only effective when input and output are of Float16 dtype.
:param compute_mode: when set to "default", no special requirements will be
placed on the precision of intermediate results. When set to "float32",
"float32" would be used for accumulator and intermediate result,
but only effective when input and output are of float16 dtype.
"""
ph, pw = _pair(padding)
sh, sw = _pair_nonzero(stride)
dh, dw = _pair_nonzero(dilation)
sparse_type = "DENSE" if groups == 1 else "GROUP"
sparse_type = "dense" if groups == 1 else "group"
op = builtin.ConvBias(
stride_h=sh,
stride_w=sw,
......@@ -84,9 +87,9 @@ def batch_conv_bias_activation(
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
nonlinear_mode="IDENTITY",
conv_mode="CROSS_CORRELATION",
compute_mode="DEFAULT",
nonlinear_mode="identity",
conv_mode="cross_correlation",
compute_mode="default",
) -> Tensor:
"""
Batch convolution bias with activation operation, only for inference.
......@@ -95,27 +98,30 @@ def batch_conv_bias_activation(
:param weight: convolution kernel in batched way.
:param bias: bias added to the result of convolution
:param stride: stride of the 2D convolution operation. Default: 1
:param padding: size of the paddings added to the input on both sides of its spatial dimensions. Only zero-padding is supported. Default: 0
:param padding: size of the paddings added to the input on both sides
of its spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 2D convolution operation. Default: 1
:param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1,
:param groups: number of groups into which the input and output channels are divided,
so as to perform a "grouped convolution". When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`.
:type conv_mode: string or :class:`Convolution.Mode`.
:param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
'CROSS_CORRELATION'
:param conv_mode: supports 'cross_correlation' or 'convolution'. Default:
'cross_correlation'
:param dtype: support for ``np.dtype``, Default: np.int8
:type compute_mode: string or
:class:`Convolution.ComputeMode`.
:param compute_mode: when set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32",
"Float32" would be used for accumulator and intermediate result, but only effective when input and output are of Float16 dtype.
:param compute_mode: when set to "default", no special requirements will be
placed on the precision of intermediate results. When set to "float32",
"float32" would be used for accumulator and intermediate result,
but only effective when input and output are of float16 dtype.
"""
ph, pw = _pair(padding)
sh, sw = _pair_nonzero(stride)
dh, dw = _pair_nonzero(dilation)
sparse_type = "DENSE" if groups == 1 else "GROUP"
sparse_type = "dense" if groups == 1 else "group"
op = builtin.BatchConvBias(
stride_h=sh,
stride_w=sw,
......
......@@ -335,12 +335,8 @@ def split(inp, nsplits_or_sections, axis=0):
y = F.split(x, 3)
z = F.split(x, [6, 17], axis=1)
if os.environ.get("MEGENGINE_USE_SYMBOLIC_SHAPE"):
print([tuple(i.shape.numpy().tolist()) for i in y])
print([tuple(i.shape.numpy().tolist()) for i in z])
else:
print([i.shape for i in y])
print([i.shape for i in z])
print([i.numpy().shape for i in y])
print([i.numpy().shape for i in z])
Outputs:
......
......@@ -46,6 +46,7 @@ def cvt_color(inp: Tensor, mode: str = ""):
[[[[0.86555195]]]]
"""
mode = mode.upper()
assert mode in builtin.CvtColor.Mode.__dict__, "unspport mode for cvt_color"
mode = getattr(builtin.CvtColor.Mode, mode)
assert isinstance(mode, builtin.CvtColor.Mode)
......@@ -92,9 +93,8 @@ def roi_pooling(
[[[-0.1383 -0.1383]
[-0.5035 -0.5035]]]
"""
assert mode in ["max", "average"], "only max/average mode is supported"
assert mode.lower() in ["max", "average"], "only max/average mode is supported"
if isinstance(output_shape, int):
output_shape = (output_shape, output_shape)
......@@ -151,6 +151,7 @@ def roi_align(
[0.1359 0.1359]]]
"""
mode = mode.lower()
assert mode in ["max", "average"], "only max/average mode is supported"
if isinstance(output_shape, int):
output_shape = (output_shape, output_shape)
......@@ -244,9 +245,9 @@ def nms(
def remap(
inp: Tensor,
map_xy: Tensor,
border_mode: str = "REPLICATE",
border_mode: str = "replicate",
scalar: float = 0.0,
interp_mode: str = "LINEAR",
interp_mode: str = "linear",
) -> Tensor:
r"""
Applies remap transformation to batched 2D images.
......@@ -257,11 +258,11 @@ def remap(
:param inp: input image
:param map_xy: (batch, oh, ow, 2) transformation matrix
:param border_mode: pixel extrapolation method.
Default: "REPLICATE". Currently also support "CONSTANT", "REFLECT",
"REFLECT_101", "WRAP".
Default: "replicate". Currently also support "constant", "reflect",
"reflect_101", "wrap".
:param scalar: value used in case of a constant border. Default: 0
:param interp_mode: interpolation methods.
Default: "LINEAR". Currently only support "LINEAR" mode.
Default: "linear". Currently only support "linear" mode.
:return: output tensor.
Examples:
......@@ -301,10 +302,10 @@ def warp_affine(
inp: Tensor,
weight: Tensor,
out_shape,
border_mode="REPLICATE",
border_mode="replicate",
border_val=0,
format="NHWC",
imode="LINEAR",
imode="linear",
):
"""
Batched affine transform on 2D images.
......@@ -313,13 +314,13 @@ def warp_affine(
:param weight: weight tensor.
:param out_shape: output tensor shape.
:param border_mode: pixel extrapolation method.
Default: "WRAP". Currently "CONSTANT", "REFLECT",
"REFLECT_101", "ISOLATED", "WRAP", "REPLICATE", "TRANSPARENT" are supported.
Default: "wrap". Currently "constant", "reflect",
"reflect_101", "isolated", "wrap", "replicate", "transparent" are supported.
:param border_val: value used in case of a constant border. Default: 0
:param format: "NHWC" as default based on historical concerns,
"NCHW" is also supported. Default: "NCHW".
:param imode: interpolation methods. Could be "LINEAR", "NEAREST", "CUBIC", "AREA".
Default: "LINEAR".
"NCHW" is also supported. Default: "NHWC".
:param imode: interpolation methods. Could be "linear", "nearest", "cubic", "area".
Default: "linear".
:return: output tensor.
.. note::
......@@ -340,9 +341,9 @@ def warp_perspective(
inp: Tensor,
M: Tensor,
dsize: Union[Tuple[int, int], int, Tensor],
border_mode: str = "REPLICATE",
border_mode: str = "replicate",
border_val: float = 0.0,
interp_mode: str = "LINEAR",
interp_mode: str = "linear",
) -> Tensor:
r"""
Applies perspective transformation to batched 2D images.
......@@ -359,11 +360,11 @@ def warp_perspective(
:param M: `(batch, 3, 3)` transformation matrix.
:param dsize: `(h, w)` size of the output image.
:param border_mode: pixel extrapolation method.
Default: "REPLICATE". Currently also support "CONSTANT", "REFLECT",
"REFLECT_101", "WRAP".
Default: "replicate". Currently also support "constant", "reflect",
"reflect_101", "wrap".
:param border_val: value used in case of a constant border. Default: 0
:param interp_mode: interpolation methods.
Default: "LINEAR". Currently only support "LINEAR" mode.
Default: "linear". Currently only support "linear" mode.
:return: output tensor.
Note:
......@@ -409,7 +410,7 @@ def interpolate(
inp: Tensor,
size: Optional[Union[int, Tuple[int, int]]] = None,
scale_factor: Optional[Union[float, Tuple[float, float]]] = None,
mode: str = "BILINEAR",
mode: str = "bilinear",
align_corners: Optional[bool] = None,
) -> Tensor:
r"""
......@@ -419,9 +420,9 @@ def interpolate(
:param size: size of the output tensor. Default: None
:param scale_factor: scaling factor of the output tensor. Default: None
:param mode: interpolation methods, acceptable values are:
"BILINEAR", "LINEAR". Default: "BILINEAR"
"bilinear", "linear". Default: "bilinear"
:param align_corners: This only has an effect when `mode`
is "BILINEAR" or "LINEAR". Geometrically, we consider the pixels of the input
is "bilinear" or "linear". Geometrically, we consider the pixels of the input
and output as squares rather than points. If set to ``True``, the input
and output tensors are aligned by the center points of their corner
pixels, preserving the values at the corner pixels. If set to ``False``,
......@@ -455,10 +456,10 @@ def interpolate(
[3. 3.25 3.75 4. ]]]]
"""
mode = mode.upper()
if mode not in ["BILINEAR", "LINEAR"]:
mode = mode.lower()
if mode not in ["bilinear", "linear"]:
raise ValueError("interpolate only support linear or bilinear mode")
if mode not in ["BILINEAR", "LINEAR"]:
if mode not in ["bilinear", "linear"]:
if align_corners is not None:
raise ValueError(
"align_corners option can only be set in the bilinear/linear interpolating mode"
......@@ -471,16 +472,16 @@ def interpolate(
size is not None
and scale_factor is None
and not align_corners
and mode == "BILINEAR"
and mode == "bilinear"
and inp.ndim in [4, 5]
):
# fastpath for interpolate
op = builtin.Resize(imode="LINEAR", format="NCHW")
op = builtin.Resize(imode="linear", format="NCHW")
shape = astensor1d(size, inp, dtype="int32", device=inp.device)
(result,) = apply(op, inp, shape)
return result
if mode == "LINEAR":
if mode == "linear":
inp = expand_dims(inp, 3)
if inp.ndim != 4:
......@@ -492,14 +493,14 @@ def interpolate(
if isinstance(scale_factor, (float, int)):
scale_factor = float(scale_factor)
if mode == "LINEAR":
if mode == "linear":
scale_factor = (scale_factor, float(1))
else:
scale_factor = (scale_factor, scale_factor)
else:
if mode == "LINEAR":
if mode == "linear":
raise ValueError(
"under LINEAR mode, scale_factor can only be single value"
"under linear mode, scale_factor can only be single value"
)
assert len(scale_factor) == 2, "shape of scale_factor must be equal to (2, )"
......@@ -524,8 +525,8 @@ def interpolate(
if isinstance(size, int):
size = (size, 1)
else:
if mode == "LINEAR":
raise ValueError("under LINEAR mode, size can only be single value")
if mode == "linear":
raise ValueError("under linear mode, size can only be single value")
dsize = size
oh, ow = dsize[0], dsize[1]
......@@ -534,7 +535,7 @@ def interpolate(
if align_corners:
hscale = (ih - 1.0) / (oh - 1.0)
wscale = 1.0 * iw / ow
if mode != "LINEAR":
if mode != "linear":
wscale = (iw - 1.0) / (ow - 1.0)
row0 = concat(
[wscale, Tensor([0, 0], dtype="float32", device=inp.device)], axis=0
......@@ -570,8 +571,8 @@ def interpolate(
weight = broadcast_to(weight, (inp.shape[0], 3, 3))
weight = weight.astype("float32")
ret = warp_perspective(inp, weight, dsize, interp_mode="LINEAR")
if mode == "LINEAR":
ret = warp_perspective(inp, weight, dsize, interp_mode="linear")
if mode == "linear":
ret = reshape(ret, ret.shape[0:3])
return ret
......
......@@ -24,7 +24,7 @@ class BatchMatMulActivation(Module):
in_features: int,
out_features: int,
bias: bool = True,
nonlinear_mode="IDENTITY",
nonlinear_mode="identity",
**kwargs
):
super().__init__(**kwargs)
......@@ -37,7 +37,7 @@ class BatchMatMulActivation(Module):
if bias:
b_shape = (out_features,)
self.bias = Parameter(np.zeros(b_shape, dtype=np.float32))
self.nonlinear_mode = nonlinear_mode
self.nonlinear_mode = nonlinear_mode.lower()
self.reset_parameters()
def _get_fanin(self):
......@@ -54,7 +54,7 @@ class BatchMatMulActivation(Module):
res = matmul(weight, x)
if self.bias is not None:
res += bias
if self.nonlinear_mode == "RELU":
if self.nonlinear_mode == "relu":
res = relu(res)
return res
......
......@@ -138,11 +138,11 @@ class Conv1d(_ConvNd):
out_channel // groups, in_channels // groups, *kernel_size)`.
:param bias: whether to add a bias onto the result of convolution. Default:
True
:param conv_mode: Supports `CROSS_CORRELATION`. Default:
`CROSS_CORRELATION`
:param compute_mode: When set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32",
"Float32" would be used for accumulator and intermediate result, but only
:param conv_mode: Supports `cross_correlation`. Default:
`cross_correlation`
:param compute_mode: When set to "default", no special requirements will be
placed on the precision of intermediate results. When set to "float32",
"float32" would be used for accumulator and intermediate result, but only
effective when input and output are of float16 dtype.
Examples:
......@@ -176,8 +176,8 @@ class Conv1d(_ConvNd):
dilation: int = 1,
groups: int = 1,
bias: bool = True,
conv_mode: str = "CROSS_CORRELATION",
compute_mode: str = "DEFAULT",
conv_mode: str = "cross_correlation",
compute_mode: str = "default",
**kwargs
):
kernel_size = kernel_size
......@@ -298,11 +298,11 @@ class Conv2d(_ConvNd):
out_channel // groups, in_channels // groups, *kernel_size)`.
:param bias: whether to add a bias onto the result of convolution. Default:
True
:param conv_mode: Supports `CROSS_CORRELATION`. Default:
`CROSS_CORRELATION`
:param compute_mode: When set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32",
"Float32" would be used for accumulator and intermediate result, but only
:param conv_mode: Supports `cross_correlation`. Default:
`cross_correlation`
:param compute_mode: When set to "default", no special requirements will be
placed on the precision of intermediate results. When set to "float32",
"float32" would be used for accumulator and intermediate result, but only
effective when input and output are of float16 dtype.
Examples:
......@@ -336,8 +336,8 @@ class Conv2d(_ConvNd):
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: bool = True,
conv_mode: str = "CROSS_CORRELATION",
compute_mode: str = "DEFAULT",
conv_mode: str = "cross_correlation",
compute_mode: str = "default",
**kwargs
):
kernel_size = _pair_nonzero(kernel_size)
......@@ -436,15 +436,16 @@ class Conv3d(_ConvNd):
:param padding: size of the paddings added to the input on both sides of its
spatial dimensions. Only zero-padding is supported. Default: 0
:param dilation: dilation of the 3D convolution operation. Default: 1
:param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1,
:param groups: number of groups into which the input and output channels are divided,
so as to perform a "grouped convolution". When ``groups`` is not 1,
``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and there would be an extra dimension at the beginning of the weight's
shape. Specifically, the shape of weight would be `(groups,
out_channel // groups, in_channels // groups, *kernel_size)`.
:param bias: whether to add a bias onto the result of convolution. Default:
True
:param conv_mode: Supports `CROSS_CORRELATION`. Default:
`CROSS_CORRELATION`
:param conv_mode: Supports `cross_correlation`. Default:
`cross_correlation`
Examples:
......@@ -477,7 +478,7 @@ class Conv3d(_ConvNd):
dilation: Union[int, Tuple[int, int, int]] = 1,
groups: int = 1,
bias: bool = True,
conv_mode: str = "CROSS_CORRELATION",
conv_mode: str = "cross_correlation",
):
kernel_size = _triple_nonzero(kernel_size)
stride = _triple_nonzero(stride)
......@@ -566,11 +567,11 @@ class ConvTranspose2d(_ConvNd):
out_channels // groups, in_channels // groups, *kernel_size)``. Default: 1
:param bias: wether to add a bias onto the result of convolution. Default:
True
:param conv_mode: Supports `CROSS_CORRELATION`. Default:
`CROSS_CORRELATION`
:param compute_mode: When set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32",
"Float32" would be used for accumulator and intermediate result, but only
:param conv_mode: Supports `cross_correlation`. Default:
`cross_correlation`
:param compute_mode: When set to "default", no special requirements will be
placed on the precision of intermediate results. When set to "float32",
"float32" would be used for accumulator and intermediate result, but only
effective when input and output are of float16 dtype.
"""
......@@ -584,8 +585,8 @@ class ConvTranspose2d(_ConvNd):
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: bool = True,
conv_mode: str = "CROSS_CORRELATION",
compute_mode: str = "DEFAULT",
conv_mode: str = "cross_correlation",
compute_mode: str = "default",
**kwargs
):
kernel_size = _pair_nonzero(kernel_size)
......@@ -679,7 +680,7 @@ class LocalConv2d(Conv2d):
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
conv_mode: str = "CROSS_CORRELATION",
conv_mode: str = "cross_correlation",
**kwargs
):
self.input_height = input_height
......@@ -758,11 +759,11 @@ class DeformableConv2d(_ConvNd):
out_channel // groups, in_channels // groups, *kernel_size)`.
:param bias: whether to add a bias onto the result of convolution. Default:
True
:param conv_mode: Supports `CROSS_CORRELATION`. Default:
`CROSS_CORRELATION`
:param compute_mode: When set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32",
"Float32" would be used for accumulator and intermediate result, but only
:param conv_mode: Supports `cross_correlation`. Default:
`cross_correlation`
:param compute_mode: When set to "default", no special requirements will be
placed on the precision of intermediate results. When set to "float32",
"float32" would be used for accumulator and intermediate result, but only
effective when input and output are of float16 dtype.
"""
......@@ -776,8 +777,8 @@ class DeformableConv2d(_ConvNd):
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: bool = True,
conv_mode: str = "CROSS_CORRELATION",
compute_mode: str = "DEFAULT",
conv_mode: str = "cross_correlation",
compute_mode: str = "default",
**kwargs
):
kernel_size = _pair_nonzero(kernel_size)
......
......@@ -24,8 +24,8 @@ class _ConvBnActivation2d(Module):
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: bool = True,
conv_mode: str = "CROSS_CORRELATION",
compute_mode: str = "DEFAULT",
conv_mode: str = "cross_correlation",
compute_mode: str = "default",
eps=1e-5,
momentum=0.9,
affine=True,
......
......@@ -18,58 +18,58 @@ class Elemwise(Module):
:param method: the elemwise method, support the following string.
It will do the normal elemwise operator for float.
* "ADD": a + b
* "FUSE_ADD_RELU": max(x+y, 0)
* "MUL": x * y
* "MIN": min(x, y)
* "MAX": max(x, y)
* "SUB": x - y
* "TRUE_DIV": x / y
* "FUSE_ADD_SIGMOID": sigmoid(x + y)
* "FUSE_ADD_TANH": tanh(x + y)
* "RELU": x > 0 ? x : 0
* "ABS": x > 0 ? x : -x
* "SIGMOID": sigmoid(x)
* "EXP": exp(x)
* "TANH": tanh(x)
* "FUSE_MUL_ADD3": x * y + z
* "FAST_TANH": x * (27. + x * x) / (27. + 9. * x * x)
* "NEGATE": -x
* "ACOS": acos(x)
* "ASIN": asin(x)
* "CEIL": ceil(x)
* "COS": cos(x)
* "EXPM1": expm1(x)
* "FLOOR": floor(x)
* "LOG": log(x)
* "LOG1P": log1p(x)
* "SIN": sin(x)
* "ROUND": round(x)
* "ERF": erf(x)
* "ERFINV": erfinv(x)
* "ERFC": erfc(x)
* "ERFCINV": erfcinv(x)
* "ABS_GRAD": abs_grad
* "FLOOR_DIV": floor_div
* "MOD": mod
* "SIGMOID_GRAD": sigmoid_grad
* "SWITCH_GT0": switch_gt0
* "TANH_GRAD": tanh_grad
* "LT": less
* "LEQ": leq
* "EQ": equal
* "POW": pow
* "LOG_SUM_EXP": log_sum_exp
* "FAST_TANH_GRAD": fast_tanh_grad
* "ATAN2": atan2
* "COND_LEQ_MOV": cond_leq_mov
* "H_SWISH": h_swish
* "FUSE_ADD_H_SWISH": h_swish(x+y)
* "H_SWISH_GRAD": h_swish_grad
* "AND": bool binary: x && y
* "OR": bool binary: x || y
* "XOR": bool binary: x ^ y
* "NOT": bool unary: ~x
* "add": a + b
* "fuse_add_relu": max(x+y, 0)
* "mul": x * y
* "min": min(x, y)
* "max": max(x, y)
* "sub": x - y
* "true_div": x / y
* "fuse_add_sigmoid": sigmoid(x + y)
* "fuse_add_tanh": tanh(x + y)
* "relu": x > 0 ? x : 0
* "abs": x > 0 ? x : -x
* "sigmoid": sigmoid(x)
* "exp": exp(x)
* "tanh": tanh(x)
* "fuse_mul_add3": x * y + z
* "fast_tanh": x * (27. + x * x) / (27. + 9. * x * x)
* "negate": -x
* "acos": acos(x)
* "asin": asin(x)
* "ceil": ceil(x)
* "cos": cos(x)
* "expm1": expm1(x)
* "floor": floor(x)
* "log": log(x)
* "log1p": log1p(x)
* "sin": sin(x)
* "round": round(x)
* "erf": erf(x)
* "erfinv": erfinv(x)
* "erfc": erfc(x)
* "erfcinv": erfcinv(x)
* "abs_grad": abs_grad
* "floor_div": floor_div
* "mod": mod
* "sigmoid_grad": sigmoid_grad
* "switch_gt0": switch_gt0
* "tanh_grad": tanh_grad
* "lt": less
* "leq": leq
* "eq": equal
* "pow": pow
* "log_sum_exp": log_sum_exp
* "fast_tanh_grad": fast_tanh_grad
* "atan2": atan2
* "cond_leq_mov": cond_leq_mov
* "h_swish": h_swish
* "fuse_add_h_swish": h_swish(x+y)
* "h_swish_grad": h_swish_grad
* "and": bool binary: x && y
* "or": bool binary: x || y
* "xor": bool binary: x ^ y
* "not": bool unary: ~x
"""
def __init__(self, method, **kwargs):
......
......@@ -27,7 +27,7 @@ class BatchMatMulActivation(Float.BatchMatMulActivation, QuantizedModule):
in_features: int,
out_features: int,
bias: bool = True,
nonlinear_mode="IDENTITY",
nonlinear_mode="identity",
dtype=None,
**kwargs
):
......
......@@ -34,8 +34,8 @@ class Conv2d(Float.Conv2d, QuantizedModule):
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
conv_mode: str = "CROSS_CORRELATION",
compute_mode: str = "DEFAULT",
conv_mode: str = "cross_correlation",
compute_mode: str = "default",
dtype=None,
**kwargs
):
......@@ -53,7 +53,7 @@ class Conv2d(Float.Conv2d, QuantizedModule):
)
self.output_dtype = dtype
def calc_conv_quantized(self, inp, nonlinear_mode="IDENTITY"):
def calc_conv_quantized(self, inp, nonlinear_mode="identity"):
inp_scale = dtype.get_scale(inp.dtype)
w_scale = dtype.get_scale(self.weight.dtype)
bias_scale = inp_scale * w_scale
......@@ -100,11 +100,11 @@ class Conv2d(Float.Conv2d, QuantizedModule):
return qconv
def forward(self, inp):
return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")
return self.calc_conv_quantized(inp, nonlinear_mode="identity")
class ConvRelu2d(Conv2d):
r"""Quantized version of :class:`~.qat.ConvRelu2d`."""
def forward(self, inp):
return self.calc_conv_quantized(inp, nonlinear_mode="RELU")
return self.calc_conv_quantized(inp, nonlinear_mode="relu")
......@@ -50,11 +50,11 @@ class ConvBn2d(_ConvBnActivation2d):
r"""Quantized version of :class:`~.qat.ConvBn2d`."""
def forward(self, inp):
return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")
return self.calc_conv_quantized(inp, nonlinear_mode="identity")
class ConvBnRelu2d(_ConvBnActivation2d):
r"""Quantized version of :class:`~.qat.ConvBnRelu2d`."""
def forward(self, inp):
return self.calc_conv_quantized(inp, nonlinear_mode="RELU")
return self.calc_conv_quantized(inp, nonlinear_mode="relu")
......@@ -16,7 +16,7 @@ class Elemwise(QuantizedModule):
def __init__(self, method, dtype=None, **kwargs):
super().__init__(**kwargs)
self.method = "Q" + method
self.method = "q" + method
self.output_dtype = dtype
def forward(self, *inps):
......
......@@ -16,9 +16,9 @@ fi
export MEGENGINE_LOGGING_LEVEL="ERROR"
pushd $(dirname "${BASH_SOURCE[0]}")/.. >/dev/null
PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest $test_dirs -m 'not isolated_distributed'
PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'not isolated_distributed'
if [[ "$TEST_PLAT" == cuda ]]; then
echo "test GPU pytest now"
PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest $test_dirs -m 'isolated_distributed'
PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'isolated_distributed'
fi
popd >/dev/null
......@@ -372,7 +372,7 @@ def test_interpolate_fastpath():
x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x))
y = F.vision.interpolate(x, size=(16, 16), mode="BILINEAR")
y = F.vision.interpolate(x, size=(16, 16), mode="bilinear")
grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones(x_np.shape, dtype=np.float32) / 4, x.grad.numpy())
......
......@@ -162,7 +162,7 @@ def test_qadd():
x = tensor(x, dtype=dtype.qint8(inp_scale))
y = tensor(y, dtype=dtype.qint8(inp_scale))
result_mge = F.elemwise._elemwise_multi_type(
x, y, mode="QADD", dtype=dtype.qint8(outp_scale)
x, y, mode="qadd", dtype=dtype.qint8(outp_scale)
)
result_mge = result_mge.astype("float32").numpy()
result_expect = x.astype("float32").numpy() + y.astype("float32").numpy()
......
......@@ -140,8 +140,8 @@ def test_interpolate():
def linear_interpolate():
inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
out = F.vision.interpolate(inp, scale_factor=2.0, mode="LINEAR")
out2 = F.vision.interpolate(inp, 4, mode="LINEAR")
out = F.vision.interpolate(inp, scale_factor=2.0, mode="linear")
out2 = F.vision.interpolate(inp, 4, mode="linear")
np.testing.assert_allclose(
out.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32)
......@@ -170,13 +170,13 @@ def test_interpolate():
inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
with pytest.raises(ValueError):
F.vision.interpolate(inp, scale_factor=2.0, mode="LINEAR")
F.vision.interpolate(inp, scale_factor=2.0, mode="linear")
def inappropriate_scale_linear_interpolate():
inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
with pytest.raises(ValueError):
F.vision.interpolate(inp, scale_factor=[2.0, 3.0], mode="LINEAR")
F.vision.interpolate(inp, scale_factor=[2.0, 3.0], mode="linear")
linear_interpolate()
many_batch_interpolate()
......@@ -339,18 +339,18 @@ def test_interpolate_fastpath():
]
for inp_shape, target_shape in test_cases:
x = tensor(np.random.randn(*inp_shape), dtype=np.float32)
out = F.vision.interpolate(x, target_shape, mode="BILINEAR")
out = F.vision.interpolate(x, target_shape, mode="bilinear")
assert out.shape[0] == x.shape[0] and out.shape[1] == x.shape[1]
assert out.shape[2] == target_shape[0] and out.shape[3] == target_shape[1]
# check value
x = tensor(np.ones((3, 3, 10, 10)), dtype=np.float32)
out = F.vision.interpolate(x, (15, 5), mode="BILINEAR")
out = F.vision.interpolate(x, (15, 5), mode="bilinear")
np.testing.assert_equal(out.numpy(), np.ones((3, 3, 15, 5)).astype(np.float32))
np_x = np.arange(32)
x = tensor(np_x).astype(np.float32).reshape(1, 1, 32, 1)
out = F.vision.interpolate(x, (1, 1), mode="BILINEAR")
out = F.vision.interpolate(x, (1, 1), mode="bilinear")
np.testing.assert_equal(out.item(), np_x.mean())
......@@ -374,7 +374,7 @@ def test_warp_affine():
inp_shape = (1, 3, 3, 3)
x = tensor(np.arange(27, dtype=np.float32).reshape(inp_shape))
weightv = [[[1.26666667, 0.6, -83.33333333], [-0.33333333, 1, 66.66666667]]]
outp = F.vision.warp_affine(x, tensor(weightv), (2, 2), border_mode="WRAP")
outp = F.vision.warp_affine(x, tensor(weightv), (2, 2), border_mode="wrap")
res = np.array(
[
[
......@@ -509,7 +509,7 @@ def test_conv_bias():
SH,
SW,
has_bias=True,
nonlinear_mode="IDENTITY",
nonlinear_mode="identity",
):
inp_v = np.random.normal(size=(N, IC, IH, IW))
w_v = np.random.normal(size=(OC, IC, KH, KW))
......@@ -541,7 +541,7 @@ def test_conv_bias():
O = F.conv2d(
inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
)
if nonlinear_mode == "RELU":
if nonlinear_mode == "relu":
return F.relu(O)
else:
return O
......@@ -583,8 +583,8 @@ def test_conv_bias():
run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "RELU")
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "relu")
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")
@pytest.mark.skipif(
......
......@@ -23,8 +23,8 @@ def test_module_elemwise():
y = np.random.rand(100).astype("float32")
x, y = tensor(x), tensor(y)
np.testing.assert_almost_equal(
test_func("H_SWISH", x), F.hswish(x).numpy(), decimal=6
test_func("h_swish", x), F.hswish(x).numpy(), decimal=6
)
np.testing.assert_almost_equal(
test_func("ADD", x, y), F.add(x, y).numpy(), decimal=6
test_func("add", x, y), F.add(x, y).numpy(), decimal=6
)
......@@ -133,7 +133,7 @@ def test_dequant_stub():
np.testing.assert_allclose(q, fake_quant_normal.numpy())
@pytest.mark.parametrize("kind", ["COS", "RELU", "ADD", "MUL", "FUSE_ADD_RELU"])
@pytest.mark.parametrize("kind", ["cos", "relu", "add", "mul", "fuse_add_relu"])
def test_elemwise(kind):
normal_net = Float.Elemwise(kind)
normal_net.eval()
......@@ -167,7 +167,7 @@ def test_elemwise(kind):
x2_int8 = quant(x2, x2_scale)
# test correctness of `Float`, `QAT` and `Quantized`
if kind in ("ADD", "MUL", "FUSE_ADD_RELU"):
if kind in ("add", "mul", "fuse_add_relu"):
normal = normal_net(x1, x2)
qat_without_fakequant = qat_from_float(x1, x2)
fake_quant_normal = fake_quant_act(normal_net(x1, x2), act_scale)
......
......@@ -22,7 +22,7 @@ def fake_quant(x, scale):
return x
@pytest.mark.parametrize("kind", ["ABS", "SIN", "SUB", "MUL", "FUSE_ADD_TANH"])
@pytest.mark.parametrize("kind", ["abs", "sin", "sub", "mul", "fuse_add_tanh"])
def test_elemwise(kind):
x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
x1_scale = np.float32(np.random.rand() + 1)
......@@ -39,8 +39,8 @@ def test_elemwise(kind):
output_scale = np.float32(np.random.rand() + 1)
output_dtype = dtype.qint8(output_scale)
quantized_kind = "Q" + kind
if kind in ("ABS", "SIN"):
quantized_kind = "q" + kind
if kind in ("abs", "sin"):
desired_out = fake_quant(_elwise(x1, mode=kind), output_scale)
actual_out = (
_elemwise_multi_type(
......@@ -84,7 +84,7 @@ def test_conv_bias():
SH,
SW,
has_bias=True,
nonlinear_mode="IDENTITY",
nonlinear_mode="identity",
):
inp_v = np.random.normal(size=(N, IC, IH, IW))
w_v = np.random.normal(size=(OC, IC, KH, KW))
......@@ -116,7 +116,7 @@ def test_conv_bias():
O = F.conv2d(
inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
)
if nonlinear_mode == "RELU":
if nonlinear_mode == "relu":
return F.relu(O)
else:
return O
......@@ -158,5 +158,5 @@ def test_conv_bias():
run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "RELU")
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "relu")
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")
......@@ -280,7 +280,7 @@ def test_convbias():
@trace(symbolic=True, capture_as_const=True)
def fwd(inp, weight, bias):
return F.quantized.conv_bias_activation(
inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="RELU"
inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="relu"
)
inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0))
......@@ -297,7 +297,7 @@ def test_batch_convbias():
@trace(symbolic=True, capture_as_const=True)
def fwd(inp, weight, bias):
return F.quantized.batch_conv_bias_activation(
inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="RELU"
inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="relu"
)
inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0))
......@@ -358,7 +358,7 @@ def test_warpaffine():
@trace(symbolic=True, capture_as_const=True)
def fwd(x, weightv):
return F.vision.warp_affine(x, weightv, (2, 2), border_mode="WRAP")
return F.vision.warp_affine(x, weightv, (2, 2), border_mode="wrap")
outp = fwd(x, weightv)
check_pygraph_dump(fwd, [x, weightv], [outp])
......@@ -387,7 +387,7 @@ def test_resize():
@trace(symbolic=True, capture_as_const=True)
def fwd(x):
return F.vision.interpolate(x, size=(16, 16), mode="BILINEAR")
return F.vision.interpolate(x, size=(16, 16), mode="bilinear")
out = fwd(x)
check_pygraph_dump(fwd, [x], [out])
......@@ -697,7 +697,7 @@ def test_assert_equal():
def test_elemwise_multitype():
op = builtin.ElemwiseMultiType(mode="QADD", dtype=dtype.qint32(2.0))
op = builtin.ElemwiseMultiType(mode="qadd", dtype=dtype.qint32(2.0))
@trace(symbolic=True, capture_as_const=True)
def fwd(x, y):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册