...
 
Commits (7)
    https://gitcode.net/megvii/megengine/-/commit/2f06d580b9d69415df34d114bbf67a61e006192c feat(xla): add topk and sort for xla 2023-07-31T10:04:07+08:00 Megvii Engine Team megengine@megvii.com GitOrigin-RevId: 0e881f30429a8d849ad9cdd0e0f47c3e0921ff97 https://gitcode.net/megvii/megengine/-/commit/c2b9d5428456601e4ef10d3372fdbfcc94b142b7 fix(lite): fix the possibility of obtaining incorrect host device type when c... 2023-08-07T16:04:02+08:00 Megvii Engine Team megengine@megvii.com GitOrigin-RevId: db0c8c071239bc7f3d7afaf97203426550d06402 https://gitcode.net/megvii/megengine/-/commit/de084f92ba7780f63cf86174b05f56f34fd46181 fix(src/atlas): remove the limitation that the input of the om model must be ... 2023-08-07T20:03:53+08:00 Megvii Engine Team megengine@megvii.com GitOrigin-RevId: ab3fce058e8ce1fa47e05142170b06a33bdcbf95 https://gitcode.net/megvii/megengine/-/commit/abb7f6eff9503bcb407099430e52b8a1668c0a2e fix(src/atlas): support static input batch and dynamic output batch 2023-08-07T20:04:00+08:00 Megvii Engine Team megengine@megvii.com GitOrigin-RevId: 78df430e68e417c64e1f5aae89bd39b1fb91cb60 https://gitcode.net/megvii/megengine/-/commit/35167e53665d5752350b4978ce3481b9f5b9e30a feat(imperative): add augs including emboss, sharpen, linearcontrast 2023-08-10T11:06:05+08:00 Megvii Engine Team megengine@megvii.com GitOrigin-RevId: c050784b9d5be33932c483ea954418b2e4ba3310 https://gitcode.net/megvii/megengine/-/commit/801265e9c98a7ae6636ae345f4dfd8f6bca0824b fix(imperative): fix linearconst test 2023-08-10T11:06:12+08:00 Megvii Engine Team megengine@megvii.com GitOrigin-RevId: 52d64c35190de65270aadffef6c449945ed16804 https://gitcode.net/megvii/megengine/-/commit/66b79160d35b2710c00befede0c3fd729109e474 fix(src/gopt): fix padding channel pass bug that hasn't insert a subtensor be... 2023-08-11T14:04:12+08:00 Megvii Engine Team megengine@megvii.com GitOrigin-RevId: 01bd8c70e96f72e6c5089d73ec7da607462837db
......@@ -38,8 +38,12 @@ from .rnn import LSTM, RNN, LSTMCell, RNNCell
from .sequential import Sequential
from .sliding_window import SlidingWindow, SlidingWindowTranspose
from .vision import (
ActiveBlur,
AdditiveElemwise,
AdditiveGaussianNoise,
AdditiveLaplaceNoise,
AdditivePoissonNoise,
Emboss,
LinearContrast,
Sharpen,
)
import math
import numbers
from functools import lru_cache
import numpy as np
from ..core.ops import builtin
from ..core.tensor.utils import subgraph_fn
from ..functional import (
arange,
broadcast_to,
clip,
flatten,
full_like,
gather,
mul,
reshape,
zeros,
)
from ..functional.elemwise import abs, add, log
from ..functional.math import sign
from ..functional.nn import conv2d, pad
from ..functional.tensor import broadcast_to
from ..random.rng import RNG
from ..tensor import Tensor
......@@ -54,12 +72,11 @@ class AdditiveLaplaceNoise(AdditiveElemwise):
r"""Add random laplace noise to the input data.
Laplace noise is generated with given mean and std, sampled from Laplace distribution
ref to this page to learn more: https://en.wikipedia.org/wiki/Laplace_distribution
Args:
mean: laplace mean used to generate noise.
std: laplace standard deviation used to generate noise.
per_channel: Whether to use (imagewise) the same sample(s) for all channels (False) or to sample value(s) for each channel (True). Setting this to True will therefore lead to different transformations per image and channel, otherwise only per image.
per_channel: whether to use (imagewise) the same sample(s) for all channels (False) or to sample value(s) for each channel (True). Setting this to True will therefore lead to different transformations per image and channel, otherwise only per image.
seed: random number seed of generator
"""
......@@ -94,7 +111,7 @@ class AdditivePoissonNoise(AdditiveElemwise):
Args:
lam: lam parameter of poisson distribution used to generate noise.
per_channel: Whether to use (imagewise) the same sample(s) for all channels (False) or to sample value(s) for each channel (True). Setting this to True will therefore lead to different transformations per image and channel, otherwise only per image.
per_channel: whether to use (imagewise) the same sample(s) for all channels (False) or to sample value(s) for each channel (True). Setting this to True will therefore lead to different transformations per image and channel, otherwise only per image.
seed: random number seed of generator
"""
......@@ -125,9 +142,9 @@ class AdditiveGaussianNoise(AdditiveElemwise):
Gaussian noise is generated with given mean and std.
Args:
mean: Gaussian mean used to generate noise.
std: Gaussian standard deviation used to generate noise.
per_channel: Whether to use (imagewise) the same sample(s) for all channels (False) or to sample value(s) for each channel (True). Setting this to True will therefore lead to different transformations per image and channel, otherwise only per image.
mean: gaussian mean used to generate noise.
std: gaussian standard deviation used to generate noise.
per_channel: whether to use (imagewise) the same sample(s) for all channels (False) or to sample value(s) for each channel (True). Setting this to True will therefore lead to different transformations per image and channel, otherwise only per image.
seed: random number seed of generator
"""
......@@ -152,3 +169,344 @@ class AdditiveGaussianNoise(AdditiveElemwise):
assert isinstance(seed, int)
self._seed = seed
self.rng_func = RNG(seed).normal
def _get_value_range_of_dtype(dtype):
if not dtype.kind in ["f", "u", "i", "b"]:
raise Exception(
"Cannot estimate value range of dtype '%s' "
"(type: %s)" % (str(dtype), type(dtype))
)
if dtype.kind == "f":
finfo = np.finfo(dtype)
value_min = finfo.min
value_mid = 0.0
value_max = finfo.max
if dtype.kind == "u":
iinfo = np.iinfo(dtype)
value_min = iinfo.min
value_mid = iinfo.min + 0.5 * iinfo.max
value_max = iinfo.max
if dtype.kind == "i":
iinfo = np.iinfo(dtype)
value_min = iinfo.min
value_mid = -0.5
value_max = iinfo.max
if dtype.kind == "b":
value_min = 0
value_mid = None
value_max = 1
return value_min, value_mid, value_max
def _check_out_dtype(inp, input_dtype):
    """Convert ``inp`` back to ``input_dtype``, saturating where needed.

    Args:
        inp: the computed result to convert.
        input_dtype: dtype of the original input; its ``name`` selects the
            conversion strategy.

    Returns:
        ``inp`` cast to ``input_dtype``. For ``bool`` the values are first
        thresholded at ``0.5``; for the listed narrow integer types and
        ``float16`` they are first clipped to the representable range of
        ``input_dtype``. Other dtypes are cast directly.
    """
    dtype_name = input_dtype.name
    clipped_names = ["uint8", "uint16", "int8", "int16", "int32", "float16"]
    if dtype_name == "bool":
        inp = inp > 0.5
    elif dtype_name in clipped_names:
        lower, _, upper = _get_value_range_of_dtype(input_dtype)
        inp = clip(inp, lower, upper)
    return inp.astype(input_dtype)
class ActiveBlur(Module):
    """Base class for augmentations that convolve the input with a kernel.

    ``forward`` pads the two trailing axes by one element (reflect mode),
    convolves with the kernel returned by :meth:`get_kernel` using
    ``groups=C`` and converts the result back to the input dtype.
    Subclasses must implement :meth:`get_kernel` and provide
    ``self.rng_func`` (used by :meth:`_get_parameter` for tuple parameters).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, inp):
        """Apply the kernel from :meth:`get_kernel` to a 4-d tensor.

        Raises:
            AssertionError: if ``inp`` is not a ``Tensor``.
            RuntimeError: if ``inp.format`` is neither ``"nchw"`` nor
                ``"default"``.
        """
        assert isinstance(
            inp, Tensor
        ), "expected input is megengine.Tensor, but got {}".format(type(inp))
        if inp.format not in ("nchw", "default"):
            raise RuntimeError(
                "expect you create Tensor with format NCHW, got format is {}".format(
                    inp.format
                )
            )
        # unpacking also enforces a 4-d input
        _, channels, _, _ = inp.shape
        kernel = self.get_kernel(inp, channels)
        # one-element reflect border on the last two axes only
        border = ((0, 0), (0, 0), (1, 1), (1, 1))
        padded = pad(inp, pad_width=border, mode="reflect")
        blurred = conv2d(padded, kernel, groups=channels)
        return _check_out_dtype(blurred, inp.dtype)

    def _get_parameter(self, param):
        """Resolve ``param`` to a Python float.

        A number is converted directly; a 2-tuple ``(a, b)`` is passed to
        ``self.rng_func(a, b)`` and the sampled value is converted.

        Raises:
            TypeError: if ``param`` is a bool or any other unsupported type.
        """
        if isinstance(param, bool):
            raise TypeError("The input parameter cannot be of bool value type. ")
        if isinstance(param, (numbers.Integral, numbers.Real)):
            return float(param)
        if isinstance(param, tuple):
            assert len(param) == 2, (
                "Expected parameter with type tuple to have exactly two "
                "entries, but got %d." % len(param)
            )
            sampled = self.rng_func(param[0], param[1])
            return float(sampled)
        raise TypeError("The input parameter has a wrong type. ")

    def get_kernel(self, inp, c):
        """Return the convolution kernel for ``c`` channels (subclass hook)."""
        raise NotImplementedError()
@lru_cache(maxsize=None)
def _get_EmbossKernel_op(alpha, strength, *, dtype=None, device=None):
    """Build the subgraph function that produces the blended emboss kernel.

    The result is memoized by ``lru_cache`` on
    ``(alpha, strength, dtype, device)``, so all arguments must be hashable.

    Args:
        alpha: blend factor, baked into the subgraph as a float32 constant.
        strength: emboss strength, baked in as a float32 constant.
        dtype: forwarded to ``subgraph_fn``.
        device: forwarded to ``subgraph_fn``.

    Returns:
        The ``subgraph_fn``-decorated ``EmbossKernel`` callable, declared with
        two inputs: the effect matrix followed by the "no change" matrix.

    NOTE(review): the cache is unbounded; when ``alpha``/``strength`` are
    sampled randomly per call every distinct pair adds an entry -- confirm
    this growth is acceptable.
    """

    @subgraph_fn(
        "EmbossKernel", dtype=dtype, device=device, nr_inputs=2, gopt_level=None,
    )
    def EmbossKernel(input, f, c):
        # presumably ``f`` applies an op/elemwise mode and ``c`` creates a
        # constant inside the subgraph -- verify against ``subgraph_fn``
        inp_e, inp_n = input[0:2]
        c_alp = c(alpha, dtype="float32")
        c_sub_alp = c(1 - alpha, dtype="float32")
        c_stg = c(strength, dtype="float32")
        # integer constants used as slice bounds
        c_1 = c(1, dtype="int32")
        c_2 = c(2, dtype="int32")
        c_3 = c(3, dtype="int32")

        def _subtensor(src, axis, begin, end):
            # slice ``src[begin:end]`` along ``axis``; a ``None`` bound is
            # omitted from both the item flags and the operand list.
            # NOTE(review): the two trailing ``False`` flags presumably mean
            # "no step" / "no index" -- confirm against builtin.Subtensor.
            items = ((axis, (begin is not None), (end is not None), False, False),)
            args = ()
            if begin is not None:
                args += (begin,)
            if end is not None:
                args += (end,)
            return f(builtin.Subtensor(items=items), src, *args)

        def _kernel_init(x):
            # split the matrix into its three rows
            k_1 = _subtensor(x, 0, None, c_1)
            k_2 = _subtensor(x, 0, c_1, c_2)
            k_3 = _subtensor(x, 0, c_2, c_3)
            # subtract the strength towards the top-left, add it towards
            # the bottom-right
            k_11 = f("-", _subtensor(k_1, 1, None, c_1), c_stg)
            # taken from x[0][1]; reused below for position (1, 0) as well
            k_12_21 = f("-", _subtensor(k_1, 1, c_1, c_2), c_stg)
            # taken from x[1][2]; reused below for position (2, 1) as well
            k_23_32 = f("+", _subtensor(k_2, 1, c_2, c_3), c_stg)
            k_33 = f("+", _subtensor(k_3, 1, c_2, c_3), c_stg)
            # entries that are passed through unchanged
            k_13 = _subtensor(k_1, 1, c_2, c_3)
            k_22 = _subtensor(k_2, 1, c_1, c_2)
            k_31 = _subtensor(k_3, 1, None, c_1)
            # reassemble the rows, then stack them
            nk_1 = f(builtin.Concat(axis=1), k_11, k_12_21, k_13,)
            nk_2 = f(builtin.Concat(axis=1), k_12_21, k_22, k_23_32,)
            nk_3 = f(builtin.Concat(axis=1), k_31, k_23_32, k_33,)
            return f(builtin.Concat(axis=0), nk_1, nk_2, nk_3,)

        def _kernel_calc(k_e, k_n):
            # alpha blend: (1 - alpha) * no_change + alpha * effect
            k1 = f("*", k_n, c_sub_alp)
            k2 = f("*", k_e, c_alp)
            return f("+", k1, k2)

        kernel_effect = _kernel_init(inp_e)
        kernel = _kernel_calc(kernel_effect, inp_n)
        # NOTE(review): the second tuple presumably flags per-output gradient
        # requirements -- confirm against ``subgraph_fn``
        return (kernel,), (False,)

    return EmbossKernel
class Emboss(ActiveBlur):
    r"""Emboss the input and alpha-blend the result with the original input.

    The embossed version pronounces highlights and shadows: it enhances the
    high-frequency information of the image while retaining its low-frequency
    information.

    Args:
        alpha: visibility of the embossed image, a number or a tuple of numbers.
            At ``0.0`` only the original image is visible, at ``1.0`` only its
            embossed version is visible. If a tuple ``(a, b)``, a random value
            is sampled from the interval ``[a, b)``.
        strength: emboss strength, a number or a tuple of numbers. Sane values
            are somewhere in the interval ``[0.0, 2.0)``. If a tuple ``(a, b)``,
            a random value is sampled from the interval ``[a, b)``.
        seed: random number seed of generator

    Examples:
        >>> import numpy as np
        >>> inp = mge.tensor(np.random.randint(0, 255, size=(160,3,128,128)).astype("float32"))
        >>> aug = mge.module.Emboss(alpha=(0.6, 0.8), strength=(0.6, 0.8), seed=1)
        >>> out = aug(inp)
    """

    def __init__(self, alpha, strength, seed=None):
        assert seed is None or isinstance(seed, int)
        super().__init__()
        self.alpha = alpha
        self.strength = strength
        self.rng_func = RNG(seed).uniform
        self.seed = seed
        # kernel that leaves the image untouched
        identity = np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]], dtype=np.float32)
        self.matrix_nochange = Tensor(identity)
        # base emboss kernel; the strength is applied when the kernel is built
        emboss = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)
        self.matrix_effect = Tensor(emboss)

    def get_kernel(self, inp, c):
        """Resolve ``alpha``/``strength`` and return the kernel broadcast to
        shape ``(c, 1, 1, rows, cols)``."""
        alpha = self._get_parameter(self.alpha)
        strength = self._get_parameter(self.strength)
        effect = self.matrix_effect
        build_kernel = _get_EmbossKernel_op(
            alpha, strength, dtype=effect.dtype, device=effect.device,
        )
        kernel, *_ = build_kernel(effect, self.matrix_nochange)
        target_shape = (c, 1, 1, kernel.shape[0], kernel.shape[1])
        return broadcast_to(kernel, target_shape)
@lru_cache(maxsize=None)
def _get_SharpenKernel_op(alpha, lightness, *, dtype=None, device=None):
    """Build the subgraph function that produces the blended sharpen kernel.

    The result is memoized by ``lru_cache`` on
    ``(alpha, lightness, dtype, device)``, so all arguments must be hashable.

    Args:
        alpha: blend factor, baked into the subgraph as a float32 constant.
        lightness: value added to the centre element of the effect matrix,
            baked in as a float32 constant.
        dtype: forwarded to ``subgraph_fn``.
        device: forwarded to ``subgraph_fn``.

    Returns:
        The ``subgraph_fn``-decorated ``SharpenKernel`` callable, declared with
        two inputs: the effect matrix followed by the "no change" matrix.

    NOTE(review): the cache is unbounded; when ``alpha``/``lightness`` are
    sampled randomly per call every distinct pair adds an entry -- confirm
    this growth is acceptable.
    """

    @subgraph_fn(
        "SharpenKernel", dtype=dtype, device=device, nr_inputs=2, gopt_level=None,
    )
    def SharpenKernel(input, f, c):
        # presumably ``f`` applies an op/elemwise mode and ``c`` creates a
        # constant inside the subgraph -- verify against ``subgraph_fn``
        inp_e, inp_n = input[0:2]
        c_alp = c(alpha, dtype="float32")
        c_sub_alp = c(1 - alpha, dtype="float32")
        c_lts = c(lightness, dtype="float32")
        # integer constants used as slice bounds
        c_1 = c(1, dtype="int32")
        c_2 = c(2, dtype="int32")
        c_3 = c(3, dtype="int32")

        def _subtensor(src, axis, begin, end):
            # slice ``src[begin:end]`` along ``axis``; a ``None`` bound is
            # omitted from both the item flags and the operand list.
            # NOTE(review): the two trailing ``False`` flags presumably mean
            # "no step" / "no index" -- confirm against builtin.Subtensor.
            items = ((axis, (begin is not None), (end is not None), False, False),)
            args = ()
            if begin is not None:
                args += (begin,)
            if end is not None:
                args += (end,)
            return f(builtin.Subtensor(items=items), src, *args)

        def _kernel_init(x):
            # split the matrix into its three rows
            k_1 = _subtensor(x, 0, None, c_1)
            k_2 = _subtensor(x, 0, c_1, c_2)
            k_3 = _subtensor(x, 0, c_2, c_3)
            # only the centre element of the middle row is modified
            k_21 = _subtensor(k_2, 1, None, c_1)
            k_22 = f("+", _subtensor(k_2, 1, c_1, c_2), c_lts)
            k_23 = _subtensor(k_2, 1, c_2, c_3)
            nk_2 = f(builtin.Concat(axis=1), k_21, k_22, k_23,)
            # first and last rows are passed through unchanged
            return f(builtin.Concat(axis=0), k_1, nk_2, k_3,)

        def _kernel_calc(k_e, k_n):
            # alpha blend: (1 - alpha) * no_change + alpha * effect
            k1 = f("*", k_n, c_sub_alp)
            k2 = f("*", k_e, c_alp)
            return f("+", k1, k2)

        kernel_effect = _kernel_init(inp_e)
        kernel = _kernel_calc(kernel_effect, inp_n)
        # NOTE(review): the second tuple presumably flags per-output gradient
        # requirements -- confirm against ``subgraph_fn``
        return (kernel,), (False,)

    return SharpenKernel
class Sharpen(ActiveBlur):
    r"""Sharpen images and alpha-blend the result with the original input.

    Args:
        alpha: visibility of the sharpened image, a number or a tuple of
            numbers. At ``0.0`` only the original image is visible, at ``1.0``
            only its sharpened version is visible. If a tuple ``(a, b)``, a
            random value is sampled from the interval ``[a, b)``.
        lightness: controls the brightness of sharpened images, a number or a
            tuple of numbers. Sane values are somewhere in the interval
            ``[0.5, 2.0)``. If a tuple ``(a, b)``, a random value is sampled
            from the interval ``[a, b)``.
        seed: random number seed of generator

    Examples:
        >>> import numpy as np
        >>> inp = mge.tensor(np.random.randint(0, 255, size=(160,3,128,128)).astype("float32"))
        >>> aug = mge.module.Sharpen(alpha=(0.6, 0.8), lightness=(0.6, 0.8), seed=1)
        >>> out = aug(inp)
    """

    def __init__(self, alpha, lightness, seed=None):
        assert seed is None or isinstance(seed, int)
        super().__init__()
        self.alpha = alpha
        self.lightness = lightness
        self.rng_func = RNG(seed).uniform
        self.seed = seed
        # kernel that leaves the image untouched
        identity = np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]], dtype=np.float32)
        self.matrix_nochange = Tensor(identity)
        # base sharpen kernel; lightness is added to the centre when built
        sharpen = np.array(
            [[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=np.float32
        )
        self.matrix_effect = Tensor(sharpen)

    def get_kernel(self, inp, c):
        """Resolve ``alpha``/``lightness`` and return the kernel broadcast to
        shape ``(c, 1, 1, rows, cols)``."""
        alpha = self._get_parameter(self.alpha)
        lightness = self._get_parameter(self.lightness)
        effect = self.matrix_effect
        build_kernel = _get_SharpenKernel_op(
            alpha, lightness, dtype=effect.dtype, device=effect.device,
        )
        kernel, *_ = build_kernel(effect, self.matrix_nochange)
        target_shape = (c, 1, 1, kernel.shape[0], kernel.shape[1])
        return broadcast_to(kernel, target_shape)
class LinearContrast(Module):
    r"""Adjust contrast by scaling each pixel to ``127 + alpha*(v-127)``.

    For ``uint8`` inputs the mapping is applied through a 256-entry lookup
    table (one row per image, or per image and channel). For other dtypes
    the formula is evaluated directly in float32 around the centre value
    reported by ``_get_value_range_of_dtype`` and cast back to the input dtype.

    Args:
        alpha: number or tuple of number. If a tuple ``(a, b)``, a random value
            will be sampled from the interval ``[a, b)``.
        per_channel: whether to use (imagewise) the same sample(s) for all
            channels (False) or to sample value(s) for each channel (True).
            Setting this to True will therefore lead to different
            transformations per image and channel, otherwise only per image.
        seed: random number seed of generator

    Examples:
        >>> import numpy as np
        >>> inp = mge.tensor(np.random.randint(0, 255, size=(160,3,128,128)).astype("float32"))
        >>> aug = mge.module.LinearContrast(alpha=(0.6, 0.8), per_channel=False, seed=1)
        >>> out = aug(inp)
    """

    def __init__(self, alpha, per_channel=False, seed=None):
        super().__init__()
        self.alpha = alpha
        self.seed = seed
        self.per_channel = per_channel
        self.rng_func = RNG(seed).uniform

    def _get_parameter(self, param, size):
        """Return a float32 tensor of shape ``size`` holding the parameter.

        A number fills the whole tensor; a 2-tuple ``(a, b)`` is forwarded to
        ``self.rng_func(a, b, size)``.

        Raises:
            TypeError: if ``param`` is a bool or any other unsupported type.
        """
        if isinstance(param, bool):
            raise TypeError("The input parameter cannot be of bool value type. ")
        if isinstance(param, (numbers.Integral, numbers.Real)):
            value = zeros(size, dtype="float32")
            value = full_like(value, param)
            return value
        elif isinstance(param, tuple):
            assert len(param) == 2, (
                "Expected parameter with type tuple to have exactly two "
                "entries, but got %d." % len(param)
            )
            value = self.rng_func(param[0], param[1], size)
            return value
        else:
            raise TypeError("The input parameter has a wrong type. ")

    def _get_table(self, size):
        """Build a ``(size, 256)`` lookup table for uint8 inputs.

        Row ``r`` maps pixel value ``v`` to
        ``clip(127 + alpha[r] * (v - 127), 0, 255)``.
        """
        shape = (size, 1)
        alpha = self._get_parameter(self.alpha, shape)
        # uint8 takes 256 distinct values (0..255); a 255-entry table would
        # leave pixel value 255 without an entry to gather from
        table = arange(256).astype("float32")
        table = broadcast_to(table, (size, 256))
        table = 127 + mul((table - 127), alpha)
        return clip(table, 0, 255)

    def forward(self, inp: Tensor) -> Tensor:
        """Apply the contrast adjustment to a 4-d input tensor."""
        if inp.dtype.name == "uint8":
            # one table row per (image, channel) pair or per image
            if self.per_channel is True:
                flatten_inp = reshape(
                    inp, (inp.shape[0] * inp.shape[1], inp.shape[2] * inp.shape[3])
                ).astype("int32")
            else:
                flatten_inp = flatten(inp, 1).astype("int32")
            table = self._get_table(flatten_inp.shape[0])
            # pixel values act as column indices into the table
            result = gather(table, 1, flatten_inp)
            result = reshape(result, inp.shape).astype("uint8")
            return result
        else:
            input_dtype = inp.dtype
            _, center_value, _ = _get_value_range_of_dtype(input_dtype)
            # alpha broadcasts over the remaining axes
            if self.per_channel is True:
                size = (inp.shape[0], inp.shape[1], 1, 1)
            else:
                size = (inp.shape[0], 1, 1, 1)
            alpha = self._get_parameter(self.alpha, size)
            if input_dtype.kind in ["u", "i"]:
                center_value = int(center_value)
            result = center_value + mul(inp.astype("float32") - center_value, alpha)
            result = result.astype(input_dtype)
            return result
......@@ -192,6 +192,7 @@ class TraceResult:
dtype_to_str = {
"float16": "f16",
"float32": "f32",
"int8": "i8",
"int32": "i32",
"int64": "i64",
"uint8": "u8",
......@@ -417,6 +418,10 @@ def f32_attr(i):
return ir.FloatAttr.get(ir.F32Type.get(), i)
def bool_attr(i):
    """Return ``i`` wrapped as an ``ir.BoolAttr``."""
    attr = ir.BoolAttr.get(i)
    return attr
def precision_attr(lhs_prec, rhs_prec) -> ir.ArrayAttr:
lhs_prec = str(lhs_prec)
rhs_prec = str(rhs_prec)
......
......@@ -66,7 +66,7 @@ def _hslice_with_step_is_one(inp, slices):
def _hslice_with_any_step(inp, slices):
"""
if inp_shape is N-dim, slices should contain N slice, slice can not None
if inp_shape is N-dim, slices should contain N slice, slice can not None.
for shape [12, 15], slices can be [slice(0, 3, 1), slice(12, 15, 1)]
"""
starts = [int(sl.start) for sl in slices]
......@@ -83,7 +83,7 @@ def _hslice_with_any_step(inp, slices):
def index_with_slices(inp, slices):
"""
if inp_shape is N-dim, slices should contain N slice, slice can be None
if inp_shape is N-dim, slices should contain N slice, slice can be None.
for shape [12, 15], slices can be [slice(0, 3, 1), slice(12, 15, 1)] or [None, None]
"""
assert isinstance(slices, Sequence), f"{slices}"
......
......@@ -4,9 +4,13 @@ import numpy as np
from ...core._imperative_rt import ops as mops
from .. import ir_utils
from ..ir_utils import i64_attr
from ..ir_utils import bool_attr, i64_attr
from ..lib.mlir import ir
from ..lib.mlir.dialects import chlo, hlo
from ..utils import flatten_list
from .hlotensor import HLOTensor
from .indexing import ScatterDimensionNumbers, scatter
from .tensor import concat, expand_dims, fill, iota
from .utils import _can_broadcast_to, _shape_equal, register_lower_rule
......@@ -241,5 +245,192 @@ def batched_matmul_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
).transpose(permutation)
def _sort_according_to_key(key, *vals, axis=-1, descending=True, is_stable=True):
    """
    Sort ``key`` along ``axis`` and reorder every tensor in ``vals`` the same way.

    ``key`` and all ``vals`` must have the same shape. The comparator only
    looks at ``key``; each ``val`` element simply travels with its key.

    Args:
        key: tensor whose values define the order.
        *vals: tensors reordered together with ``key``.
        axis: axis to sort along; a negative value is offset by ``key.ndim``.
        descending: compare with ``>`` when true, otherwise with ``<``.
        is_stable: forwarded to the sort op as its ``is_stable`` attribute.

    Returns:
        A generator yielding ``HLOTensor`` wrappers for the sorted key followed
        by the reordered vals, in the order they were passed in.

    example 1: (implement argsort, each value is paired with its position)
        inp:
            key:  [[ 1.7783 -1.8184  1.0701]
                   [-0.0712 -1.4623  1.3243]]
            val:  [[0 1 2]
                   [0 1 2]]
        axis: -1
        descend: True
        return: (the positions follow their values)
            key:  [[ 1.7783  1.0701 -1.8184]
                   [ 1.3243 -0.0712 -1.4623]]
            val:  [[0 2 1]
                   [2 0 1]]

    example 2:
        inp:
            key:  [[0 2 1]
                   [2 0 1]]
            val:  [[ 1.7783  1.0701 -1.8184]
                   [ 1.3243 -0.0712 -1.4623]]
        axis: -1
        descend: False
        return:
            key:  [[0 1 2]
                   [0 1 2]]
            val:  [[ 1.7783 -1.8184  1.0701]
                   [-0.0712 -1.4623  1.3243]]
    """
    for val in vals:
        assert _shape_equal(
            key.shape, val.shape
        ), f"sort key and vals shape mismatch: {key.shape}, {val.shape}"
    axis = axis + key.ndim if axis < 0 else axis

    # result types: same shape/dtype as the corresponding operand
    sorted_key = ir_utils.make_ir_type_according_meta(key.shape, key.dtype)
    sorted_vals = [
        ir_utils.make_ir_type_according_meta(val.shape, val.dtype) for val in vals
    ]

    sort_op = hlo.SortOp(
        [sorted_key, *sorted_vals],
        [key.tensor, *[val.tensor for val in vals]],
        dimension=i64_attr(axis),
        is_stable=bool_attr(is_stable),
    )

    # the comparator block takes two rank-0 arguments per operand; each type is
    # paired with itself here (presumably flattened to lhs/rhs order by
    # ``flatten_list`` -- verify)
    key_type = ir_utils.make_ir_type_according_meta(tuple(), key.dtype)
    val_types = [
        ir_utils.make_ir_type_according_meta(tuple(), val.dtype) for val in vals
    ]
    arg_types = [key_type] + val_types
    comparator = sort_op.comparator.blocks.append(
        *flatten_list(zip(arg_types, arg_types))
    )
    with ir.InsertionPoint(comparator):
        # only the first two block arguments (the key pair) take part in the
        # comparison
        lhs = HLOTensor(comparator.arguments[0])
        rhs = HLOTensor(comparator.arguments[1])

        if descending:
            hlo.ReturnOp([(lhs > rhs).tensor])
        else:
            hlo.ReturnOp([(lhs < rhs).tensor])

    assert len(sort_op.results) == len(vals) + 1, f"{len(vals)}, {len(sort_op.results)}"
    # lazily wraps the results; callers get a generator, not a list
    return (HLOTensor(ret) for ret in sort_op.results)
def argsort(inp, axis=-1, descending=True, is_stable=True):
    """
    Sort ``inp`` along the specified axis and return the sorted values together
    with the index each value had before sorting.

    for example:
        inp:
            [[ 1.7783 -1.8184  1.0701]
             [-0.0712 -1.4623  1.3243]]
        axis: -1
        descend: True
        return:
            [[ 1.7783  1.0701 -1.8184]
             [ 1.3243 -0.0712 -1.4623]]
            [[0 2 1]
             [2 0 1]]
    """
    if axis < 0:
        axis = axis + inp.ndim
    # int32 positions along ``axis``; they are reordered together with ``inp``
    positions = iota(np.int32, inp.shape, axis)
    return _sort_according_to_key(
        inp, positions, axis=axis, descending=descending, is_stable=is_stable
    )
@register_lower_rule(mops.Argsort)
def argsort_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
    """Lower ``mops.Argsort`` to a stable sort along the last axis."""
    assert (
        len(args) == 1 and len(ctx.vars_in) == 1 and len(ctx.vars_out) == 2
    ), f"{len(args)}, {len(ctx.vars_in)}, {len(ctx.vars_out)}"
    order = ctx.op.order
    supported_orders = [
        mops.Argsort.Order.DESCENDING,
        mops.Argsort.Order.ASCENDING,
    ]
    assert order in supported_orders, f"{ctx.op.order}"

    inp = args[0]
    last_axis = inp.ndim - 1  # megengine only support sort in the last dimension
    is_descending = order == mops.Argsort.Order.DESCENDING
    return argsort(inp, last_axis, is_descending, is_stable=True)
@register_lower_rule("ArgsortBackward")
def argsort_backward_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
    """Lower ``ArgsortBackward``, which serves both argsort and topk.

    Inputs are ``(dy, idx, x)``: the output gradient, the indices produced by
    the forward op, and the forward input. Returns the gradient w.r.t. ``x``.
    """
    assert (
        len(args) == 3 and len(ctx.vars_in) == 3 and len(ctx.vars_out) == 1
    ), f"{len(args)}, {len(ctx.vars_in)}, {len(ctx.vars_out)}"
    dy, idx, x = args[0], args[1], args[2]
    if _shape_equal(x.shape, dy.shape):
        # for argsort backward
        # sorting by the forward indices in ascending order moves every
        # gradient element back to the position its value came from
        _, dx = _sort_according_to_key(
            idx, dy, axis=-1, descending=False, is_stable=True
        )
    else:
        # for topk backward, only support axis=-1 and the dx is 2d tensor
        dx = fill(0, ctx.vars_out[0].shape, ctx.vars_out[0].dtype)
        # row number of every entry of ``idx``
        expander = iota(np.int32, idx.shape, dimension=0)
        # pair each row number with its column index on a new axis so every
        # selected element gets a (row, col) coordinate
        idx = expand_dims(idx, -1)
        expander = expand_dims(expander, -1)
        idx = concat([expander, idx], -1)
        dnums = ScatterDimensionNumbers(
            update_window_dims=(),
            inserted_window_dims=(0, 1),
            scatter_dims_to_operand_dims=(0, 1),
        )
        # write each gradient element into the zero tensor at its coordinate
        dx = scatter(dx, idx, dy, dnums, unique_indices=True)
    return dx
def topk(inp, k, descending=True, kth_only=False, no_sort=False):
    """
    do topk in the last dimension of inp, for example:
        inp.shape = (2, 3, 4), k = 2, out_shape = (2, 3, 2)

    Args:
        inp: input tensor.
        k: number of elements to select, must be positive.
        descending: select the ``k`` largest values when true, otherwise the
            ``k`` smallest values.
        kth_only: not supported yet, must be False.
        no_sort: not supported yet, must be False.

    Returns:
        ``(out, idx)``: the selected values and their indices.
    """
    assert k > 0, f"k of topk must bigger than 0, get {k}"
    assert not no_sort, f"no_sort must be False now"
    assert not kth_only, f"kth_only is not support now"

    def _largest_k(src):
        # chlo.TopKOp always selects the k largest entries
        return [HLOTensor(rst) for rst in chlo.TopKOp(src.tensor, i64_attr(k)).results]

    if descending:
        out, idx = _largest_k(inp)
    else:
        # the k smallest values of ``inp`` are the k largest values of
        # ``-inp``; negate the values back afterwards, indices are unaffected
        out, idx = _largest_k(-inp)
        out = -out
    return out, idx
@register_lower_rule(mops.TopK)
def topk_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
    """Lower ``mops.TopK``.

    ``k`` is read from the bound data of the second input; a negative value
    selects descending order and its absolute value is the element count.
    """
    assert (
        len(args) == 2 and len(ctx.vars_in) == 2
    ), f"{len(args)}, {len(ctx.vars_in)}, {len(ctx.vars_out)}"
    assert isinstance(
        ctx.vars_in[1].bound_data, np.ndarray
    ), f"{ctx.vars_in[1].bound_data}"
    signed_k = int(ctx.vars_in[1].bound_data)
    descending = signed_k < 0
    k = abs(signed_k)

    mode = ctx.op.mode
    if mode == mops.TopK.Mode.VALUE_IDX_SORTED:
        assert len(ctx.vars_out) == 2, f"{len(ctx.vars_out)}"
        kth_only = False
        no_sort = False
    elif mode == mops.TopK.Mode.VALUE_IDX_NOSORT:
        assert len(ctx.vars_out) == 2, f"{len(ctx.vars_out)}"
        kth_only = False
        no_sort = True
    else:
        assert (
            ctx.op.mode == mops.TopK.Mode.KTH_ONLY
        ), f"invalid mode for topk, {ctx.op.mode}"
        kth_only = True
        no_sort = False
        assert len(ctx.vars_out) == 1, f"{len(ctx.vars_out)}"
    return topk(args[0], k, descending, kth_only, no_sort)
......@@ -79,14 +79,13 @@ def transpose(inp, permutation):
def expand_dims(inp, axis):
    """Insert a size-1 dimension into ``inp`` at position ``axis``.

    Args:
        inp: tensor-like object exposing ``ndim``, ``shape`` and ``reshape``.
        axis: position of the new axis in the *output* shape. Valid range is
            ``[-inp.ndim - 1, inp.ndim]``; ``-1`` appends a trailing axis and
            ``-inp.ndim - 1`` prepends a leading one.

    Returns:
        ``inp`` reshaped so that a dimension of size 1 sits at ``axis``.
    """
    assert isinstance(axis, int), f"only int axis supported, get {axis}"
    assert (
        axis >= -inp.ndim - 1 and axis <= inp.ndim
    ), f"invalid axis {axis} for {inp.shape}"

    dst_shape = list(inp.shape)
    # negative axes count from the end of the output shape, which has one
    # more dimension than the input
    insert_pos = axis if axis >= 0 else (axis + inp.ndim + 1)
    dst_shape.insert(insert_pos, 1)
    return inp.reshape(tuple(dst_shape))
......@@ -94,14 +93,29 @@ def expand_dims(inp, axis):
@register_lower_rule(mops.Dimshuffle)
def dim_shuffle_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
    """Lower ``mops.Dimshuffle``.

    mge dimshuffle can transpose, add size-1 axes (pattern entry ``-1``) and
    drop input axes (axes missing from the pattern) simultaneously,
    for example:
        case1: (16, 32, 64) with pattern [0, 2, 1] -> (16, 64, 32)
        case2: (16, 32, 64) with pattern [0, -1, 2, -1, 1] -> (16, 1, 64, 1, 32)
        case3: (16, 1, 64, 1, 32) with pattern [0, 4, 2] -> (16, 32, 64)
    """
    assert len(args) == 1 and len(ctx.vars_in) == 1 and len(ctx.vars_out) == 1
    pattern = list(ctx.op.pattern)
    inp = args[0]

    # input axes that survive, in output order (``-1`` marks a new axis)
    kept = [item for item in pattern if item != -1]
    if len(kept) == len(pattern) and len(pattern) == inp.ndim:
        # pure permutation, nothing to add or drop
        return transpose(inp, pattern)

    # move the dropped axes to the front, then let the reshape to the known
    # output shape remove them and insert the new size-1 axes
    dropped = [i for i in range(inp.ndim) if i not in kept]
    permutation = dropped + kept
    return transpose(inp, permutation).reshape(ctx.vars_out[0].shape)
def concat(inps, axis):
assert len(inps) > 0, f"concat inputs should not be empty"
if axis < 0:
axis = axis + inps[0].ndim[0]
axis = axis + inps[0].ndim
hlo_inps = [inp.tensor for inp in inps]
......@@ -175,6 +189,21 @@ def fill(value, shape, dtype):
return broadcast_to(HLOTensor(value, dtype=dtype), shape)
def iota(dtype, shape, dimension):
    """
    Create a tensor of the given shape whose values count up from 0 along
    ``dimension`` (similar to arange, broadcast over the other axes).
    for example:
        shape = (2, 3), dimension=1, output is [[0, 1, 2], [0, 1, 2]]
        shape = (2, 3), dimension=0, output is [[0, 0, 0], [1, 1, 1]]
    a negative ``dimension`` is offset by ``len(shape)``, so ``-1`` refers to
    the last axis.
    """
    if dimension < 0:
        dimension = dimension + len(shape)
    result_type = ir_utils.make_ir_type_according_meta(shape, dtype)
    results = hlo.IotaOp(result_type, ir_utils.i64_attr(dimension)).results
    assert len(results) == 1, f"{len(results)}"
    return HLOTensor(results[0])
@register_lower_rule(mops.Fill)
def fill_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert len(args) == 1 and len(ctx.vars_in) == 1 and len(ctx.vars_out) == 1
......
import time
import platform
import numpy as np
import pytest
from megengine import Tensor
from megengine import Tensor, is_cuda_available
from megengine.functional import mean, zeros
from megengine.module import (
AdditiveGaussianNoise,
AdditiveLaplaceNoise,
AdditivePoissonNoise,
Emboss,
LinearContrast,
Sharpen,
)
......@@ -38,3 +42,50 @@ def test_AdditiveNoise(cls, per_channel, shape, format, seed):
aug_ref = cls(per_channel=per_channel, seed=seed)
aug_data_ref = aug_ref(input_tensor)
np.testing.assert_allclose(aug_data, aug_data_ref)
@pytest.mark.parametrize("cls", [Emboss, Sharpen])
@pytest.mark.parametrize(
    "shape, format, dtype",
    [
        ((128, 2, 160, 160), "default", np.uint8),
        ((128, 2, 160, 160), "default", np.float32),
    ],
)
@pytest.mark.parametrize(
    "param1, param2", [(0.5, 0.7), (0.6, 0.8), ((0.6, 0.8), (0.6, 0.8)),],
)
@pytest.mark.parametrize("seed", [1024, None])
def test_blur(cls, shape, format, dtype, param1, param2, seed):
    """Kernel augmentations run, and are reproducible for a fixed seed."""
    pixels = np.random.randint(0, 255, size=shape).astype(dtype)
    input_tensor = Tensor(pixels, device="xpux", format=format)

    first = cls(param1, param2, seed=seed)(input_tensor)
    if seed is None:
        # nothing to compare against without a fixed rng seed
        return

    second = cls(param1, param2, seed=seed)(input_tensor)
    np.testing.assert_allclose(first, second)
@pytest.mark.require_ngpu(1)
@pytest.mark.parametrize("per_channel", [False, True])
@pytest.mark.parametrize(
    "shape, format, dtype",
    [
        ((128, 2, 160, 160), "default", np.uint8),
        ((128, 2, 160, 160), "default", np.float32),
    ],
)
@pytest.mark.parametrize("param1", [0.6, 0.8, (0.6, 0.8)])
@pytest.mark.parametrize("seed", [1024, None])
def test_LinearContrast(per_channel, shape, format, dtype, param1, seed):
    """LinearContrast runs, and is reproducible for a fixed seed."""
    pixels = np.random.randint(0, 255, size=shape).astype(dtype)
    input_tensor = Tensor(pixels, device="xpux", format=format)

    first = LinearContrast(param1, per_channel=per_channel, seed=seed)(input_tensor)
    if seed is None:
        # nothing to compare against without a fixed rng seed
        return

    second = LinearContrast(param1, per_channel=per_channel, seed=seed)(input_tensor)
    np.testing.assert_allclose(first, second)
import numpy as np
import pytest
import megengine as mge
import megengine.autodiff as ad
......@@ -62,6 +63,7 @@ def lamb_update(
return exp_avg, exp_avg_sq, new_param
@pytest.mark.skip(reason="pytest aborted, the same as groupnorm")
def test_lamb():
op = LAMBUpdate(0.9, 0.999, 1, 1e-3, 0.4, 1e-8, True, False)
m_t_1 = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32)
......
......@@ -31,7 +31,6 @@ def test_matmul():
return out, lhs.grad, rhs.grad
mge_rsts = func(lhs, rhs, dout)
mge_rsts[0].numpy()
xla_rsts = func(lhs, rhs, dout)
for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
......@@ -79,3 +78,109 @@ def test_matmul():
tester((1, 2, 8, 7), (4, 2, 2, 9, 8), True, True)
tester((1, 8, 7), (4, 3, 2, 8, 9), True, False)
tester((1, 8, 7), (4, 3, 1, 9, 8), True, True)
@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_sort_and_argsort():
    """Compare two consecutive runs of a traced sort/argsort (with backward)."""

    def tester(ishape, descending, dtype=None):
        dtype = dtype or np.float32
        sort_data = tensor(np.random.randn(*ishape), dtype=dtype)
        argsort_data = tensor(np.random.randn(*ishape), dtype=dtype)
        out_grad = tensor(np.random.randn(*ishape), dtype=dtype)

        gm = GradManager()

        @jit.xla_trace(without_host=True)
        def func(inp1, inp2, dout):
            gm.attach([inp1, inp2])
            with gm:
                out, idx1 = F.sort(inp1, descending)
                idx2 = F.argsort(inp2, -descending)
                gm.backward(out, dout)
            return out, idx1, idx2, inp1.grad

        # presumably the first call runs eagerly while tracing and the
        # second one executes the compiled XLA program -- confirm
        mge_rsts = func(sort_data, argsort_data, out_grad)
        xla_rsts = func(sort_data, argsort_data, out_grad)

        for expected, actual in zip(mge_rsts, xla_rsts):
            np.testing.assert_allclose(expected.numpy(), actual.numpy(), atol=1e-5)

    shapes = [(16, 32), (16, 1), (1, 16), (1, 1), (16,), (1,)]
    for descending in [True, False]:
        for ishape in shapes:
            tester(ishape, descending)
@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_topk():
    """Compare two consecutive runs of a traced topk (with backward)."""

    def tester(ishape, k, descending, kth_only, no_sort, dtype=None):
        dtype = dtype or np.float32
        data = tensor(np.random.randn(*ishape), dtype=dtype)
        # run once outside the trace only to learn the output shape
        probe, _ = F.topk(data, k, descending, kth_only, no_sort)
        out_grad = tensor(0.1 * np.random.randn(*probe.shape), dtype=dtype)

        gm = GradManager()

        @jit.xla_trace(without_host=True)
        def func(inp, dout):
            gm.attach([inp])
            with gm:
                out, index = F.topk(inp, k, descending, kth_only, no_sort)
                gm.backward(out, dout)
            return out, index, inp.grad

        # presumably the first call runs eagerly while tracing and the
        # second one executes the compiled XLA program -- confirm
        mge_rsts = func(data, out_grad)
        xla_rsts = func(data, out_grad)

        for expected, actual in zip(mge_rsts, xla_rsts):
            np.testing.assert_allclose(expected.numpy(), actual.numpy(), atol=1e-5)

    # (shape, k) pairs; the repeated entries are intentional re-runs on fresh
    # random data
    cases = [
        ((2, 16,), 1),
        ((2, 16,), 8),
        ((1, 16,), 1),
        ((1, 16,), 5),
        ((16,), 8),
        ((16,), 8),
        ((1,), 1),
        ((1,), 1),
    ]
    for descending in [True, False]:
        for ishape, k in cases:
            tester(ishape, k, descending, False, False)
@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_topk_accuracy():
    """Compare two consecutive runs of a traced ``F.topk_accuracy``."""

    def tester(batch, nr_class, topk, dtype=None):
        dtype = dtype or np.float32
        scores = tensor(np.random.uniform(0, 1, (batch, nr_class)), dtype=dtype)
        labels = tensor(np.random.randint(0, nr_class, (batch,), np.int32))
        # run once outside the trace only to learn the output shape
        probe = F.topk_accuracy(scores, labels, topk)
        out_grad = tensor(0.1 * np.random.randn(*probe.shape), dtype=dtype)

        gm = GradManager()

        @jit.xla_trace(without_host=True)
        def func(logits, target, dout):
            gm.attach([logits])
            with gm:
                out = F.topk_accuracy(logits, target, topk)
                gm.backward(out, dout)
            return [out]

        # presumably the first call runs eagerly while tracing and the
        # second one executes the compiled XLA program -- confirm
        mge_rsts = func(scores, labels, out_grad)
        xla_rsts = func(scores, labels, out_grad)

        for expected, actual in zip(mge_rsts, xla_rsts):
            np.testing.assert_allclose(expected.numpy(), actual.numpy(), atol=1e-5)

    for batch, nr_class, topk in [(32, 1000, 10), (32, 1, 1), (1, 1000, 10), (1, 1, 1)]:
        tester(batch, nr_class, topk)
......@@ -113,6 +113,13 @@ def test_transpose():
tester((2, 3, 1), (0, 1, 2))
tester((2, 3, 1, 4), (3, 1, 0, 2))
tester((1,), ("x", 0))
# tester((1,), (0, 'x')) # bug for mge
tester((1, 2), ("x", 0, 1))
tester((1, 2), (0, "x", 1))
# tester((1, 2), (0, 1, 'x')) # bug for mge
tester((16, 32, 64), (0, "x", 2, "x", 1))
@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
......
......@@ -140,7 +140,8 @@ class TestNetwork(TestShuffleNetCuda):
network.load(self.model_path)
input_tensor = network.get_io_tensor("data")
assert input_tensor.device_type == LiteDeviceType.LITE_CPU
# the device type is cuda, but by default, the memory type is pinned memory on the host side, which is not on cuda.
assert input_tensor.device_type == LiteDeviceType.LITE_CUDA
self.do_forward(network)
......
......@@ -102,7 +102,7 @@ TensorImplDft::TensorImplDft(
LiteDeviceType TensorImplDft::get_device_type() const {
if (is_host()) {
return LiteDeviceType::LITE_CPU;
return get_device_from_locator(m_host_tensor->comp_node().locator());
} else {
return get_device_from_locator(m_dev_tensor->comp_node().locator());
}
......
......@@ -571,6 +571,17 @@ TEST(TestTensor, ConcatDevice) {
check(1);
check(2);
}
TEST(TestTensor, CudaOutputDevice) {
    //! A CUDA tensor created with is_pinned_host == true must still report
    //! LITE_CUDA once update_from_implement() has run: in that case the device
    //! type is expected to be refreshed from
    //! get_device_from_locator(m_host_tensor->comp_node().locator()).
    constexpr bool pinned_on_host = true;
    Layout tensor_layout{{1, 4}, 2};
    Tensor cuda_tensor(LiteDeviceType::LITE_CUDA, tensor_layout, pinned_on_host);

    cuda_tensor.update_from_implement();

    const auto reported_device = cuda_tensor.get_device_type();
    ASSERT_EQ(reported_device, LiteDeviceType::LITE_CUDA);
}
#endif
#endif
......
......@@ -558,7 +558,7 @@ void PaddingChannelPass::add_condition_padding_oprs_replace_func(LayoutTrans) {
if (reduce->input().size() > 1) {
can_forward_padding = false;
} else {
can_forward_padding = reduce->param().axis != 1;
can_forward_padding = axis != 1;
}
} else if (auto subtensor = opr->try_cast_final<opr::Subtensor>()) {
auto indexs = subtensor->index_desc();
......@@ -605,6 +605,7 @@ void PaddingChannelPass::add_nonpadding_oprs_replace_func(LayoutTrans) {
return serialization::copy_opr_shallow(*opr, inps, opr->config());
};
m_opr_replace_funcs[opr::Reshape::typeinfo()] = replace_nonpadding_oprs;
m_opr_replace_funcs[opr::AxisAddRemove::typeinfo()] = replace_nonpadding_oprs;
m_opr_replace_funcs[opr::GetVarShape::typeinfo()] = replace_nonpadding_oprs;
m_opr_replace_funcs[opr::Concat::typeinfo()] = replace_nonpadding_oprs;
m_opr_replace_funcs[opr::Dimshuffle::typeinfo()] = replace_nonpadding_oprs;
......
......@@ -282,6 +282,85 @@ TEST(TestGoptInference, ChannelPaddingSubtensor) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-2);
}
//! Checks PaddingChannelPass (NCHW44) on a graph that mixes convolutions with
//! AxisAddRemove:
//!   x(1,3,8,8) -> conv1(oc=8) -> conv2(oc=1) -> remove axis 1 -> add axis 1
//!              -> conv3(oc=3) -> remove axis 0
//! The optimized graph must pad the conv channels to 4 while every
//! AxisAddRemove still sees the original (unpadded) channel count, and both
//! graphs must produce the same values.
TEST(TestGoptInference, ChannelPaddingAxisAddRemove) {
    HostTensorGenerator<> gen;
    auto cn = CompNode::load("cpu0");
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    auto mkcvar = [&](const char* name, const TensorShape& shp) {
        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
    };

    auto host_x = gen({1, 3, 8, 8}, cn);
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
    //! Hybrid nchw44 mode
    opr::ConvBias::Param param_conv;
    param_conv.pad_h = param_conv.pad_w = 1;
    //! the bias gets its own name "b1" so it does not collide with the weight
    auto w1 = mkcvar("w1", {8, 3, 3, 3}), b1 = mkcvar("b1", {1, 8, 1, 1}),
         conv1 = opr::ConvBias::make(
                 x, w1, b1, param_conv, {}, OperatorNodeConfig("conv1"));
    auto w2 = mkcvar("w2", {1, 8, 1, 1}),
         conv2 = opr::Convolution::make(conv1, w2, {}, {}, OperatorNodeConfig("conv2"));
    auto remove_axis_1 = opr::AxisAddRemove::make(
            conv2, {opr::AxisAddRemove::AxisDesc::make_remove(1)},
            OperatorNodeConfig("remove_axis_1"));
    auto add_axis_1 = opr::AxisAddRemove::make(
            remove_axis_1, {opr::AxisAddRemove::AxisDesc::make_add(1)});
    auto w3 = mkcvar("w3", {3, 1, 1, 1}),
         conv3 = opr::Convolution::make(
                 add_axis_1, w3, {}, {}, OperatorNodeConfig("conv3"));
    auto remove_axis_0 = opr::AxisAddRemove::make(
            conv3, {opr::AxisAddRemove::AxisDesc::make_remove(0)},
            OperatorNodeConfig("remove_axis_0"));

    SymbolVar y_pad;
    unpack_vector(
            gopt::GraphOptimizer{}
                    .add_pass(gopt::PaddingChannelPass::make(
                            cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW44,
                            true))
                    .apply({{remove_axis_0}})
                    .endpoint_vars(),
            y_pad);

    auto conv1_opt = find_opr<opr::ConvBias>(y_pad, "conv1");
    auto conv2_opt = find_opr<opr::Convolution>(y_pad, "conv2");
    auto remove_axis_1_opt = find_opr<opr::AxisAddRemove>(y_pad, "remove_axis_1");
    auto conv3_opt = find_opr<opr::Convolution>(y_pad, "conv3");
    auto remove_axis_0_opt = find_opr<opr::AxisAddRemove>(y_pad, "remove_axis_0");
    //! the network input itself is not padded
    ASSERT_EQ(conv1_opt->input(0)->shape()[1], 3);
    //! conv2: filter and output channels are padded from 1 to 4
    ASSERT_EQ(conv2_opt->input(1)->shape()[0], 4);
    ASSERT_EQ(conv2_opt->output(0)->shape()[1], 4);
    //! the padding is stripped again before AxisAddRemove (channel back to 1)
    ASSERT_EQ(remove_axis_1_opt->input(0)->shape()[1], 1);
    //! conv3: filter and output channels are padded from 3 to 4
    ASSERT_EQ(conv3_opt->input(1)->shape()[0], 4);
    ASSERT_EQ(conv3_opt->output(0)->shape()[1], 4);
    //! the padding is stripped again before AxisAddRemove (channel back to 3)
    ASSERT_EQ(remove_axis_0_opt->input(0)->shape()[1], 3);

    graph->compile({{y_pad, {}}})
            ->to_json()
            ->writeto_fpath(
                    output_file("TestGoptInference.ChannelPaddingAxisAddRemove.json"));

    HostTensorND host_y_opt, host_y;
    auto func = graph->compile(
            {make_callback_copy(remove_axis_0, host_y),
             make_callback_copy(y_pad, host_y_opt)});
    func->execute();
    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-2);

    //! test change the input shape
    *host_x = *gen({1, 3, 32, 32}, cn);
    func->execute();
    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-2);
}
TEST(TestGoptInference, ChannelPaddingReduce) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
......
......@@ -15,12 +15,30 @@ namespace {
/**
 * \brief get mgb shape from acl shape, batch from mgb
 *
 * All dims are copied from \p acl_shape and dim 0 is then overwritten with the
 * batch. Normally that is the caller-provided \p batch. When the model reports
 * its output batch as dynamic (dims[0] == -1), the batch is instead derived
 * from the size returned by aclmdlGetOutputSizeByIndex() divided by the size
 * of one sample (\p output_dtype_size times all non-batch dims).
 *
 * \param model_desc model description queried for the output size
 * \param output_idx index of the output inside the model
 * \param output_dtype_size size of one output element, used as the base
 *      factor of the per-sample size
 * \param acl_shape output dims as reported by acl
 * \param batch batch to use when the output batch is not dynamic
 */
TensorShape acl_shape_to_mgb_shape_for_output(
        aclmdlDesc* model_desc, size_t output_idx, size_t output_dtype_size,
        aclmdlIODims acl_shape, size_t batch) {
    TensorShape ret;
    ret.ndim = acl_shape.dimCount;
    for (size_t i = 0; i < ret.ndim; ++i) {
        ret[i] = acl_shape.dims[i];
    }
    if (acl_shape.dims[0] == -1) {
        const size_t output_size = aclmdlGetOutputSizeByIndex(model_desc, output_idx);
        size_t sample_size = output_dtype_size;
        for (size_t i = 1; i < ret.ndim; ++i) {
            sample_size *= ret[i];
        }
        //! a zero dtype size or a zero non-batch dim would otherwise lead to a
        //! modulo/division by zero below
        mgb_assert(
                sample_size != 0,
                "can not deduce the dynamic output batch of output %zu: the size "
                "of one sample is 0\n",
                output_idx);
        mgb_assert(
                output_size % sample_size == 0,
                "When the input batch is static and the output batch is dynamic, it is "
                "necessary to reconfigure the output batch. The output size obtained "
                "from the aclmdlGetOutputSizeByIndex interface should be evenly "
                "divided by "
                "shapes other than the batch. expect 0, but got %zu\n",
                output_size % sample_size);
        batch = output_size / sample_size;
    }
    ret[0] = batch;
    return ret;
}
......@@ -40,9 +58,6 @@ TensorShape acl_shape_to_mgb_shape_for_input(
MGB_MARK_USED_VAR(aipp_input_fmt);
TensorShape ret;
ret.ndim = acl_shape.dimCount;
mgb_assert(
ret.ndim == 4, "Unexpected ndim form aclmdlIODims expected 4, but got %zu",
ret.ndim);
for (size_t i = 0; i < ret.ndim; ++i) {
ret[i] = acl_shape.dims[i];
}
......@@ -335,7 +350,7 @@ void AtlasRuntimeOpr::scn_do_execute() {
for (size_t i = 0; i < nr_outputs; i++) {
auto value_pair = output_getter.get(batch, i);
size_t output_size = value_pair.second;
if (enable_dynamic_batch) {
if (enable_dynamic_batch || m_dyn_batch_output[i]) {
output_size = aclmdlGetOutputSizeByIndex(m_model_desc, i);
}
aclDataBuffer* output_db =
......@@ -346,6 +361,18 @@ void AtlasRuntimeOpr::scn_do_execute() {
"%zu:%s.",
i, output(i)->cname());
aclmdlAddDatasetBuffer(model_outputs, output_db);
if (m_dyn_batch_output[i]) {
    //! describe output i with its own shape; using output(0) here would
    //! attach the wrong dims to every dynamic-batch output other than the
    //! first one
    const auto& cur_shape = output(i)->shape();
    auto tensor_ndim = cur_shape.ndim;
    std::vector<int64_t> tensor_shape(tensor_ndim, 0);
    for (size_t j = 0; j < tensor_ndim; j++) {
        tensor_shape[j] = cur_shape[j];
    }
    //! NOTE(review): no matching aclDestroyTensorDesc() is visible for this
    //! descriptor - confirm who owns it after aclmdlSetDatasetTensorDesc()
    aclTensorDesc* tensorDesc = aclCreateTensorDesc(
            aclmdlGetOutputDataType(m_model_desc, i), tensor_ndim,
            tensor_shape.data(), aclmdlGetOutputFormat(m_model_desc, i));
    aclmdlSetDatasetTensorDesc(model_outputs, tensorDesc, i);
}
}
MGB_ATLAS_CHECK(aclmdlExecute(m_model_id, model_inputs, model_outputs));
......@@ -354,6 +381,31 @@ void AtlasRuntimeOpr::scn_do_execute() {
MGB_ATLAS_CHECK(aclDestroyDataBuffer(db_ptr));
}
for (size_t i = 0; i < nr_outputs; ++i) {
if (m_dyn_batch_output[i]) {
const DeviceTensorND old_dev_tensor = output(i)->dev_tensor();
auto new_output_desc = aclmdlGetDatasetTensorDesc(model_outputs, i);
TensorShape new_shape;
new_shape.ndim = aclGetTensorDescNumDims(new_output_desc);
mgb_assert(
new_shape.ndim == old_dev_tensor.layout().ndim,
"for static input batch and dynamic output batch, the output "
"ndim should be consistent with the one before calling "
"aclmdlExecute(), so expect %zu, but got %zu",
old_dev_tensor.layout().ndim, new_shape.ndim);
for (size_t j = 0; j < new_shape.ndim; j++) {
new_shape.shape[j] = aclGetTensorDescDim(new_output_desc, j);
}
TensorLayout new_layout{
new_shape, old_dev_tensor.dtype(), old_dev_tensor.format()};
DeviceTensorND new_dev_tensor{
old_dev_tensor.comp_node(), new_layout, old_dev_tensor.dtype(),
old_dev_tensor.format()};
new_dev_tensor.reset(old_dev_tensor.storage(), new_layout);
output(i)->force_assign_dev_tensor_from_tensor(new_dev_tensor);
}
aclDataBuffer* db_ptr = aclmdlGetDatasetBuffer(model_outputs, i);
MGB_ATLAS_CHECK(aclDestroyDataBuffer(db_ptr));
}
......@@ -390,7 +442,9 @@ void AtlasRuntimeOpr::get_output_var_shape(
for (size_t i = 0; i < out_shape.size(); ++i) {
aclmdlIODims output_dims;
MGB_ATLAS_CHECK(aclmdlGetOutputDims(m_model_desc, i, &output_dims));
out_shape[i] = acl_shape_to_mgb_shape_for_output(output_dims, batch_size);
out_shape[i] = acl_shape_to_mgb_shape_for_output(
m_model_desc, i, output(i)->dtype().size(), output_dims, batch_size);
m_dyn_batch_output.push_back(output_dims.dims[0] == -1);
}
}
......
......@@ -64,6 +64,9 @@ private:
//! Atlas need a 64bit device tensor to hold dynamic batch state
DeviceTensorND m_dyn_batch_tensor;
SmallVector<size_t> m_dyn_batch_choices;
//! Used when the input batch is static and the output batch is dynamic. Different
//! from the case where the input batch is dynamic and the output batch is dynamic
mutable SmallVector<bool> m_dyn_batch_output;
};
} // namespace opr
......