提交 24c5c19b 编写于 作者: M Megvii Engine Team

fix(imperative): make functional ops support negative axis

GitOrigin-RevId: f61e01270b948ab5bd6ba32b091d7c6b8d7a0745
上级 c76e80bc
...@@ -1015,7 +1015,7 @@ class IndexingOneHotBase : public OperatorBase { ...@@ -1015,7 +1015,7 @@ class IndexingOneHotBase : public OperatorBase {
DEF_OPR_PARAM(Axis); DEF_OPR_PARAM(Axis);
protected: protected:
void deduce_layout_fwd( MGE_WIN_DECLSPEC_FUC void deduce_layout_fwd(
const TensorLayout& src, const TensorLayout& index, TensorLayout& dst); const TensorLayout& src, const TensorLayout& index, TensorLayout& dst);
void check_layout_fwd( void check_layout_fwd(
const TensorLayout& src, const TensorLayout& index, const TensorLayout& src, const TensorLayout& index,
......
...@@ -1558,7 +1558,7 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor: ...@@ -1558,7 +1558,7 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor:
) )
ones_tensor = ones(list(inp.shape) + [1], dtype=inp.dtype, device=inp.device) ones_tensor = ones(list(inp.shape) + [1], dtype=inp.dtype, device=inp.device)
op = builtin.IndexingSetOneHot(axis=inp.ndim) op = builtin.IndexingSetOneHot(axis=inp.ndim, ndim=inp.ndim)
(result,) = apply(op, zeros_tensor, inp, ones_tensor) (result,) = apply(op, zeros_tensor, inp, ones_tensor)
return result return result
...@@ -1609,7 +1609,7 @@ def indexing_one_hot( ...@@ -1609,7 +1609,7 @@ def indexing_one_hot(
array([1.], dtype=float32) array([1.], dtype=float32)
""" """
assert isinstance(src, Tensor), "src must be of Tensor type" assert isinstance(src, Tensor), "src must be of Tensor type"
op = builtin.IndexingOneHot(axis=axis) op = builtin.IndexingOneHot(axis=axis, ndim=src.ndim)
index = convert_single_value(index, dtype="int32", device=src.device) index = convert_single_value(index, dtype="int32", device=src.device)
(result,) = apply(op, src, index) (result,) = apply(op, src, index)
if not keepdims: if not keepdims:
......
...@@ -393,6 +393,8 @@ def split(inp, nsplits_or_sections, axis=0): ...@@ -393,6 +393,8 @@ def split(inp, nsplits_or_sections, axis=0):
def _get_idx(index, axis): def _get_idx(index, axis):
index_dims = len(index.shape) index_dims = len(index.shape)
idx = [] idx = []
if axis < 0:
axis += index_dims
for i in range(index_dims): for i in range(index_dims):
if i != axis: if i != axis:
shape = [1] * index_dims shape = [1] * index_dims
...@@ -457,21 +459,6 @@ def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor: ...@@ -457,21 +459,6 @@ def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
"But the input dims:{}, the index dims:{}".format(input_dims, index_dims) "But the input dims:{}, the index dims:{}".format(input_dims, index_dims)
) )
if axis < 0 or axis >= input_dims:
raise ValueError(
"Index axis {} is output of bounds, should in range [0 {})".format(
axis, input_dims
)
)
for i in range(input_dims):
if i != axis and input_shape[i] != index_shape[i]:
raise ValueError(
"The input {} and index {} must have the same size apart from axis {}".format(
input_shape, index_shape, axis
)
)
idx = _get_idx(index, axis) idx = _get_idx(index, axis)
return inp[idx].reshape(index.shape) # pylint: disable=no-member return inp[idx].reshape(index.shape) # pylint: disable=no-member
...@@ -524,7 +511,7 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: ...@@ -524,7 +511,7 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
>>> inp = Tensor(np.zeros(shape=(3,5),dtype=np.float32)) >>> inp = Tensor(np.zeros(shape=(3,5),dtype=np.float32))
>>> source = Tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]]) >>> source = Tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]])
>>> index = Tensor([[0,2,0,2,1],[2,0,1,1,2]]) >>> index = Tensor([[0,2,0,2,1],[2,0,1,1,2]])
>>> oup = F.scatter(inp, 0, index,source) >>> oup = F.scatter(inp, 0, index, source)
>>> oup.numpy() >>> oup.numpy()
array([[0.9935, 0.0718, 0.2256, 0. , 0. ], array([[0.9935, 0.0718, 0.2256, 0. , 0. ],
[0. , 0. , 0.5939, 0.357 , 0.4396], [0. , 0. , 0.5939, 0.357 , 0.4396],
...@@ -540,13 +527,6 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: ...@@ -540,13 +527,6 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
if input_dims != index_dims or input_dims != source_dims: if input_dims != index_dims or input_dims != source_dims:
raise ValueError("The input, source and index tensor must have same dimensions") raise ValueError("The input, source and index tensor must have same dimensions")
if axis < 0 or axis >= input_dims:
raise ValueError(
"Index axis {} is output of bounds, should in range [0 {})".format(
axis, input_dims
)
)
for i in range(source_dims): for i in range(source_dims):
if source_shape[i] > input_shape[i]: if source_shape[i] > input_shape[i]:
raise ValueError( raise ValueError(
...@@ -792,6 +772,8 @@ def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor: ...@@ -792,6 +772,8 @@ def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor:
>>> out.numpy().shape >>> out.numpy().shape
(2, 2, 9) (2, 2, 9)
""" """
if start_axis < 0:
start_axis += len(inp.shape)
target_shape = tuple(inp.shape[i] for i in range(start_axis)) + (-1,) target_shape = tuple(inp.shape[i] for i in range(start_axis)) + (-1,)
if end_axis != -1: if end_axis != -1:
target_shape += (*inp.shape[end_axis + 1 :],) target_shape += (*inp.shape[end_axis + 1 :],)
...@@ -1158,6 +1140,5 @@ def cumsum(inp: Tensor, axis: int): ...@@ -1158,6 +1140,5 @@ def cumsum(inp: Tensor, axis: int):
[ 4 9 15]], dtype=int32, device=xpux:0) [ 4 9 15]], dtype=int32, device=xpux:0)
""" """
assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor" assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor"
assert axis >= 0 and axis < inp.ndim, "input axis {} out of bound".format(axis)
op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False) op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False)
return apply(op, inp)[0] return apply(op, inp)[0]
...@@ -490,6 +490,84 @@ std::optional<ValueRefList> pixelShuffle_grad_rule( ...@@ -490,6 +490,84 @@ std::optional<ValueRefList> pixelShuffle_grad_rule(
return imperative::apply(op, inputs); return imperative::apply(op, inputs);
} }
std::optional<ValueRefList> indexing_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& indexing = op.cast_final_safe<IndexingOneHot>();
mgb_assert(inputs.size() == 2);
bool flag = inputs_require_grad[0];
auto&& grad_op = IndexingSetOneHot::make(indexing.axis, indexing.ndim);
SmallVector<ValueRef> inputs2;
if (flag) {
inputs2.push_back(get_shape(inputs[0]));
for (size_t i = 1; i < inputs.size(); ++i) {
inputs2.push_back(inputs[i]);
}
}
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([inputs = std::move(inputs2),
grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
SmallVector<ValueRef> ret(1);
if (grad && inputs[0]) {
ValueRefList args_(inputs.size() + 1);
auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
args_[0] = zeros;
args_[1] = inputs[1];
args_[2] = grads[0];
ret[0] = imperative::apply(*grad_op_, args_)[0];
}
return ret;
});
maker.finalize();
return imperative::apply(op, inputs);
}
std::optional<ValueRefList> indexing_set_one_hot_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& indexingSetOneHot = op.cast_final_safe<IndexingSetOneHot>();
mgb_assert(inputs.size() == 3);
SmallVector<ValueRef> inputs2;
inputs2.push_back(get_shape(inputs[0]));
inputs2.push_back(inputs[1]);
inputs2.push_back(get_shape(inputs[2]));
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([inputs = std::move(inputs2),
&indexingSetOneHot](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
SmallVector<ValueRef> ret(3);
if (!grad) {
return ret;
}
if (inputs[0]) {
auto&& grad_op = IndexingSetOneHot::make(
indexingSetOneHot.axis, indexingSetOneHot.ndim);
ValueRefList args_(inputs.size());
auto&& zeros = make_empty_tensor(grad.device(), inputs[2], grad.dtype());
args_[0] = grads[0];
args_[1] = inputs[1];
args_[2] = zeros;
ret[0] = imperative::apply(*grad_op, args_)[0];
}
if (inputs[2]) {
auto&& grad_op = IndexingOneHot::make(
indexingSetOneHot.axis, indexingSetOneHot.ndim);
ValueRefList args_(inputs.size() - 1);
args_[0] = grads[0];
args_[1] = inputs[1];
ret[2] = imperative::apply(*grad_op, args_)[0];
}
return ret;
});
maker.finalize();
return imperative::apply(op, inputs);
}
std::optional<ValueRefList> fastpathcopy_grad_rule( std::optional<ValueRefList> fastpathcopy_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) { CustomBackward& backward) {
...@@ -521,6 +599,10 @@ struct Init { ...@@ -521,6 +599,10 @@ struct Init {
CustomBackward::register_grad_rule(AddAxis::typeinfo(), addAxis_grad_rule); CustomBackward::register_grad_rule(AddAxis::typeinfo(), addAxis_grad_rule);
CustomBackward::register_grad_rule( CustomBackward::register_grad_rule(
RemoveAxis::typeinfo(), removeAxis_grad_rule); RemoveAxis::typeinfo(), removeAxis_grad_rule);
CustomBackward::register_grad_rule(
IndexingOneHot::typeinfo(), indexing_grad_rule);
CustomBackward::register_grad_rule(
IndexingSetOneHot::typeinfo(), indexing_set_one_hot_grad_rule);
CustomBackward::register_grad_rule( CustomBackward::register_grad_rule(
FastpathCopy::typeinfo(), fastpathcopy_grad_rule); FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
CustomBackward::register_grad_rule( CustomBackward::register_grad_rule(
......
...@@ -8,11 +8,15 @@ import megengine as mge ...@@ -8,11 +8,15 @@ import megengine as mge
import megengine.distributed as dist import megengine.distributed as dist
import megengine.functional as F import megengine.functional as F
import megengine.module as M import megengine.module as M
from megengine import Tensor
from megengine.core import _imperative_rt
from megengine.core._imperative_rt import CompNode, TensorAttr, imperative from megengine.core._imperative_rt import CompNode, TensorAttr, imperative
from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync
from megengine.core.autodiff.grad import Grad from megengine.core.autodiff.grad import Grad
from megengine.core.ops import builtin
from megengine.core.ops.builtin import Elemwise, Identity from megengine.core.ops.builtin import Elemwise, Identity
from megengine.functional.distributed import remote_recv, remote_send from megengine.functional.distributed import remote_recv, remote_send
from megengine.functional.tensor import ones, zeros
def _elwise(mode): def _elwise(mode):
...@@ -553,3 +557,46 @@ def test_matmul(): ...@@ -553,3 +557,46 @@ def test_matmul():
if ydim == 1 and transposeB == True: if ydim == 1 and transposeB == True:
continue continue
test_one(xdim, ydim, transposeA, transposeB) test_one(xdim, ydim, transposeA, transposeB)
def test_indexing():
    """Check the gradient of ``F.indexing_one_hot`` taken along axis ``-1``.

    Picking column 0 of a ``(1, 2)`` source and back-propagating ones must
    leave a gradient of ``[[1, 0]]`` on the source.
    """
    src = mge.Tensor(np.array([[1.0, 2.0]]).astype("float32"))
    index = mge.Tensor([0])

    with Grad() as grad:
        grad.wrt(src, callback=save_to(src))
        picked = F.indexing_one_hot(src, index, -1)
        grad(picked, F.ones_like(picked))

    expected = np.array([[1, 0]], dtype=np.float32)
    np.testing.assert_equal(expected, src.grad.numpy())
def test_indexing_set_one_hot():
    """Check both gradients produced by ``builtin.IndexingSetOneHot``.

    A ``(3, 4)`` base and a ``(3, 1)`` update tensor are combined at the
    positions given by ``index``; ones are then back-propagated. The base
    gradient must be zero at the overwritten positions and one elsewhere, and
    the update gradient must be all ones.
    """
    index = mge.tensor(np.arange(1, 4, dtype=np.int32))

    with Grad() as grad:
        base = zeros((3, 4), dtype=index.dtype, device=index.device)
        updates = ones((3, 1), dtype=index.dtype, device=index.device)
        grad.wrt(base, callback=save_to(base))
        grad.wrt(updates, callback=save_to(updates))

        set_one_hot = builtin.IndexingSetOneHot(axis=index.ndim, ndim=index.ndim)
        (result,) = apply(set_one_hot, base, index, updates)
        grad(result, F.ones_like(result))

    expected_base_grad = np.array(
        [[1, 0, 1, 1], [1, 1, 0, 1], [1, 1, 1, 0]], dtype=np.int32
    )
    np.testing.assert_equal(expected_base_grad, base.grad.numpy())

    expected_updates_grad = np.array([[1], [1], [1]], dtype=np.int32)
    np.testing.assert_equal(expected_updates_grad, updates.grad.numpy())
...@@ -6,9 +6,7 @@ import pytest ...@@ -6,9 +6,7 @@ import pytest
import megengine.autodiff as ad import megengine.autodiff as ad
import megengine.functional as F import megengine.functional as F
import megengine.optimizer as optimizer import megengine.optimizer as optimizer
from megengine import Parameter from megengine import Parameter, Tensor, tensor
from megengine import Tensor as tensor
from megengine import tensor
from megengine.autodiff import Function from megengine.autodiff import Function
from megengine.module import Module from megengine.module import Module
......
...@@ -3,15 +3,15 @@ import numpy as np ...@@ -3,15 +3,15 @@ import numpy as np
import pytest import pytest
import megengine.functional as F import megengine.functional as F
from megengine import tensor import megengine.tensor as Tensor
def test_cross_entropy_with_logits(): def test_cross_entropy_with_logits():
data = tensor([[0, 50], [0, -150]]).astype(np.float32) data = Tensor([[0, 50], [0, -150]]).astype(np.float32)
label = tensor([1, 0]).astype(np.int32) label = Tensor([1, 0]).astype(np.int32)
loss = F.nn.cross_entropy(data, label) loss = F.nn.cross_entropy(data, label)
np.testing.assert_allclose(loss.numpy(), 0.0) np.testing.assert_allclose(loss.numpy(), 0.0)
label = tensor([0, 1]).astype(np.int32) label = Tensor([0, 1]).astype(np.int32)
loss = F.nn.cross_entropy(data, label) loss = F.nn.cross_entropy(data, label)
np.testing.assert_allclose(loss.numpy(), 100) np.testing.assert_allclose(loss.numpy(), 100)
...@@ -35,19 +35,24 @@ def test_cross_entropy(): ...@@ -35,19 +35,24 @@ def test_cross_entropy():
x[i, y[i]] += np.random.rand() * 2 x[i, y[i]] += np.random.rand() * 2
x = softmax(x) x = softmax(x)
l_ref = ref(x, y) l_ref = ref(x, y)
l = F.nn.cross_entropy(tensor(x, "float32"), tensor(y, "int32"), with_logits=False) l = F.nn.cross_entropy(Tensor(x, "float32"), Tensor(y, "int32"), with_logits=False)
np.testing.assert_allclose(l.numpy(), l_ref, 1e-6, 1e-6) np.testing.assert_allclose(l.numpy(), l_ref, 1e-6, 1e-6)
l1 = F.nn.cross_entropy(
Tensor(x, "float32"), Tensor(y, "int32"), axis=-1, with_logits=False
)
np.testing.assert_allclose(l1.numpy(), l_ref, 1e-6, 1e-6)
def test_cross_entropy_reduction(): def test_cross_entropy_reduction():
logits = np.random.randn(16, 10) logits = np.random.randn(16, 10)
label = np.random.randint(10, size=[16]) label = np.random.randint(10, size=[16])
logits = tensor(logits, dtype="float32") logits = Tensor(logits, dtype="float32")
label = tensor(label, dtype="int32") label = Tensor(label, dtype="int32")
perm = np.random.permutation(16) perm = np.random.permutation(16)
logits_perm = tensor(logits[perm], dtype="float32") logits_perm = Tensor(logits[perm], dtype="float32")
label_perm = tensor(label[perm], dtype="int32") label_perm = Tensor(label[perm], dtype="int32")
loss = F.nn.cross_entropy(logits, label, reduction="none") loss = F.nn.cross_entropy(logits, label, reduction="none")
loss_perm = F.nn.cross_entropy(logits_perm, label_perm, reduction="none") loss_perm = F.nn.cross_entropy(logits_perm, label_perm, reduction="none")
...@@ -160,18 +165,18 @@ def _ctc_npy_single_seq(pred, label, blank): ...@@ -160,18 +165,18 @@ def _ctc_npy_single_seq(pred, label, blank):
def test_ctc_loss(): def test_ctc_loss():
def test_func(T, C, N): def test_func(T, C, N):
input = np.random.randn(T, N, C) input = np.random.randn(T, N, C)
input = F.softmax(tensor(input), axis=-1).numpy() input = F.softmax(Tensor(input), axis=-1).numpy()
input_lengths = np.ones(N, dtype=np.int32) * T input_lengths = np.ones(N, dtype=np.int32) * T
target_lengths = np.random.randint(low=1, high=T + 1, size=(N,), dtype=np.int32) target_lengths = np.random.randint(low=1, high=T + 1, size=(N,), dtype=np.int32)
target = np.random.randint( target = np.random.randint(
low=1, high=C, size=(sum(target_lengths)), dtype=np.int32 low=1, high=C, size=(sum(target_lengths)), dtype=np.int32
) )
input_mge = tensor(input) input_mge = Tensor(input)
input_lengths_mge = tensor(input_lengths) input_lengths_mge = Tensor(input_lengths)
target_mge = tensor(target) target_mge = Tensor(target)
target_lengths_mge = tensor(target_lengths) target_lengths_mge = Tensor(target_lengths)
blank = np.random.randint(C) blank = np.random.randint(C)
for method in ["mean", "sum", "none"]: for method in ["mean", "sum", "none"]:
......
...@@ -6,7 +6,7 @@ import pytest ...@@ -6,7 +6,7 @@ import pytest
from utils import opr_test from utils import opr_test
import megengine.functional as F import megengine.functional as F
from megengine import jit, tensor from megengine import Tensor, jit, tensor
from megengine.core._imperative_rt.core2 import apply from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin from megengine.core.ops import builtin
...@@ -61,37 +61,84 @@ def common_test_reduce(opr, ref_opr): ...@@ -61,37 +61,84 @@ def common_test_reduce(opr, ref_opr):
def test_sum(): def test_sum():
common_test_reduce(opr=F.sum, ref_opr=np.sum) common_test_reduce(opr=F.sum, ref_opr=np.sum)
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.sum(x, axis=-1)
np.testing.assert_equal(y.numpy(), np.array([6, 15]).astype(np.int32))
def test_prod(): def test_prod():
common_test_reduce(opr=F.prod, ref_opr=np.prod) common_test_reduce(opr=F.prod, ref_opr=np.prod)
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.prod(x, axis=-2)
np.testing.assert_equal(y.numpy(), np.array([4, 10, 18]).astype(np.int32))
def test_mean(): def test_mean():
common_test_reduce(opr=F.mean, ref_opr=np.mean) common_test_reduce(opr=F.mean, ref_opr=np.mean)
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.mean(x, axis=-2)
np.testing.assert_equal(y.numpy(), np.array([2.5, 3.5, 4.5]).astype(np.float32))
def test_var(): def test_var():
common_test_reduce(opr=F.var, ref_opr=np.var) common_test_reduce(opr=F.var, ref_opr=np.var)
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.var(x, axis=-2)
np.testing.assert_equal(y.numpy(), np.array([2.25, 2.25, 2.25]).astype(np.float32))
def test_std(): def test_std():
common_test_reduce(opr=F.std, ref_opr=np.std) common_test_reduce(opr=F.std, ref_opr=np.std)
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.std(x, axis=-2)
np.testing.assert_equal(y.numpy(), np.array([1.5, 1.5, 1.5]).astype(np.float32))
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.std(x, axis=-2)
np.testing.assert_equal(y.numpy(), np.array([1.5, 1.5, 1.5]).astype(np.float32))
def test_min(): def test_min():
common_test_reduce(opr=F.min, ref_opr=np.min) common_test_reduce(opr=F.min, ref_opr=np.min)
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.min(x, axis=-1)
np.testing.assert_equal(y.numpy(), np.array([1, 4]).astype(np.int32))
def test_max(): def test_max():
common_test_reduce(opr=F.max, ref_opr=np.max) common_test_reduce(opr=F.max, ref_opr=np.max)
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.max(x, axis=-1)
np.testing.assert_equal(y.numpy(), np.array([3, 6]).astype(np.int32))
def test_argmin(): def test_argmin():
common_test_reduce(opr=F.argmin, ref_opr=np.argmin) common_test_reduce(opr=F.argmin, ref_opr=np.argmin)
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.argmin(x, axis=-1)
np.testing.assert_equal(y.numpy(), np.array([0, 0]).astype(np.int32))
def test_argmax(): def test_argmax():
common_test_reduce(opr=F.argmax, ref_opr=np.argmax) common_test_reduce(opr=F.argmax, ref_opr=np.argmax)
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.argmax(x, axis=-2)
np.testing.assert_equal(y.numpy(), np.array([1, 1, 1]).astype(np.int32))
def test_norm():
    """``F.norm`` with ``axis=-1`` on a ``(2, 3)`` int32 tensor.

    The result, rounded to three decimals, is compared against
    ``[3.742, 8.775]``.
    """
    data = np.arange(1, 7, dtype=np.int32).reshape(2, 3)
    result = F.norm(Tensor(data), axis=-1)

    rounded = result.numpy().round(decimals=3)
    expected = np.array([3.742, 8.775]).astype(np.float32)
    np.testing.assert_equal(rounded, expected)
def test_sqrt(): def test_sqrt():
...@@ -136,7 +183,7 @@ def test_sort_empty(is_symbolic): ...@@ -136,7 +183,7 @@ def test_sort_empty(is_symbolic):
fn_ = fn fn_ = fn
data = np.random.random(shape).astype(np.float32) data = np.random.random(shape).astype(np.float32)
for _ in range(3): for _ in range(3):
outs = fn_(tensor(data)) outs = fn_(Tensor(data))
ref_outs = (np.sort(data), np.argsort(data)) ref_outs = (np.sort(data), np.argsort(data))
assert len(ref_outs) == len(outs) assert len(ref_outs) == len(outs)
for i in range(len(outs)): for i in range(len(outs)):
...@@ -146,6 +193,12 @@ def test_sort_empty(is_symbolic): ...@@ -146,6 +193,12 @@ def test_sort_empty(is_symbolic):
def test_normalize(): def test_normalize():
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.normalize(x, axis=-1)
np.testing.assert_equal(
y.numpy().round(decimals=1),
np.array([[0.3, 0.5, 0.8], [0.5, 0.6, 0.7]]).astype(np.float32),
)
cases = [ cases = [
{"input": np.random.random((2, 3, 12, 12)).astype(np.float32)} for i in range(2) {"input": np.random.random((2, 3, 12, 12)).astype(np.float32)} for i in range(2)
...@@ -177,11 +230,11 @@ def test_sum_neg_axis(): ...@@ -177,11 +230,11 @@ def test_sum_neg_axis():
shape = (2, 3) shape = (2, 3)
data = np.random.random(shape).astype(np.float32) data = np.random.random(shape).astype(np.float32)
for axis in (-1, -2, (-2, 1), (-1, 0)): for axis in (-1, -2, (-2, 1), (-1, 0)):
get = F.sum(tensor(data), axis=axis) get = F.sum(Tensor(data), axis=axis)
ref = np.sum(data, axis=axis) ref = np.sum(data, axis=axis)
np.testing.assert_allclose(get.numpy(), ref, rtol=1e-6) np.testing.assert_allclose(get.numpy(), ref, rtol=1e-6)
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
F.sum(tensor(data), axis=(-1, 1)) F.sum(Tensor(data), axis=(-1, 1))
def test_builtin_reduce(): def test_builtin_reduce():
...@@ -204,18 +257,18 @@ def test_non_finite(): ...@@ -204,18 +257,18 @@ def test_non_finite():
data = [] data = []
for i in range(2): for i in range(2):
data.append(np.random.random(shape).astype(np.float32)) data.append(np.random.random(shape).astype(np.float32))
tensorList = [tensor(x) for x in data] tensorList = [Tensor(x) for x in data]
rst = F.math._check_non_finite(tensorList, 0.7) rst = F.math._check_non_finite(tensorList, 0.7)
np.testing.assert_equal(rst.numpy(), [0]) np.testing.assert_equal(rst.numpy(), [0])
for i in range(len(tensorList)): for i in range(len(tensorList)):
np.testing.assert_allclose(tensorList[i].numpy() / 0.7, data[i], rtol=1e-6) np.testing.assert_allclose(tensorList[i].numpy() / 0.7, data[i], rtol=1e-6)
data[1][0][0][0][0] = float("inf") data[1][0][0][0][0] = float("inf")
rst = F.math._check_non_finite([tensor(x) for x in data], 0.7) rst = F.math._check_non_finite([Tensor(x) for x in data], 0.7)
np.testing.assert_equal(rst.numpy(), [1]) np.testing.assert_equal(rst.numpy(), [1])
data[1][0][0][0][0] = float("nan") data[1][0][0][0][0] = float("nan")
rst = F.math._check_non_finite([tensor(x) for x in data], 0.7) rst = F.math._check_non_finite([Tensor(x) for x in data], 0.7)
np.testing.assert_equal(rst.numpy(), [1]) np.testing.assert_equal(rst.numpy(), [1])
...@@ -237,7 +290,7 @@ def test_topk(descending, sorted, inp1d, kth_only): ...@@ -237,7 +290,7 @@ def test_topk(descending, sorted, inp1d, kth_only):
return np.sort(x) return np.sort(x)
res = F.topk( res = F.topk(
tensor(data), k, descending=descending, no_sort=(not sorted), kth_only=kth_only Tensor(data), k, descending=descending, no_sort=(not sorted), kth_only=kth_only
) )
values, indices = res values, indices = res
...@@ -268,7 +321,7 @@ def test_reduce_on_empty_tensor(is_trace): ...@@ -268,7 +321,7 @@ def test_reduce_on_empty_tensor(is_trace):
if is_trace: if is_trace:
fn = jit.trace(symbolic=symbolic)(fn) fn = jit.trace(symbolic=symbolic)(fn)
for i in range(3): for i in range(3):
out = fn(tensor(input, dtype=dtype), axis=axis).numpy() out = fn(Tensor(input, dtype=dtype), axis=axis).numpy()
out_ref = ref_fn(input.astype(dtype), axis=axis) out_ref = ref_fn(input.astype(dtype), axis=axis)
np.testing.assert_equal(out, out_ref) np.testing.assert_equal(out, out_ref)
......
...@@ -7,7 +7,7 @@ import pytest ...@@ -7,7 +7,7 @@ import pytest
from utils import get_var_value, make_tensor, opr_test from utils import get_var_value, make_tensor, opr_test
import megengine.functional as F import megengine.functional as F
from megengine import tensor from megengine import Tensor
from megengine.core._trace_option import use_symbolic_shape from megengine.core._trace_option import use_symbolic_shape
from megengine.core.tensor import megbrain_graph as G from megengine.core.tensor import megbrain_graph as G
from megengine.core.tensor.utils import astensor1d from megengine.core.tensor.utils import astensor1d
...@@ -30,7 +30,7 @@ def test_eye(): ...@@ -30,7 +30,7 @@ def test_eye():
np.eye(*case["input"]).astype(dtype), np.eye(*case["input"]).astype(dtype),
) )
np.testing.assert_allclose( np.testing.assert_allclose(
F.eye(tensor(case["input"]), dtype=dtype).numpy(), F.eye(Tensor(case["input"]), dtype=dtype).numpy(),
np.eye(*case["input"]).astype(dtype), np.eye(*case["input"]).astype(dtype),
) )
...@@ -60,7 +60,21 @@ def test_full(): ...@@ -60,7 +60,21 @@ def test_full():
values = [True, 4, 5.0] values = [True, 4, 5.0]
for value in values: for value in values:
np.testing.assert_allclose(F.full(shape, value).numpy(), np.full(shape, value)) np.testing.assert_allclose(F.full(shape, value).numpy(), np.full(shape, value))
assert F.full(shape, value).dtype == tensor(value).dtype assert F.full(shape, value).dtype == Tensor(value).dtype
@pytest.mark.parametrize("is_varnode", [True, False])
def test_cumsum(is_varnode):
    """``F.cumsum`` along axis ``-1`` of a ``(2, 3)`` int32 tensor."""
    # NOTE(review): `network` is created for the varnode case but nothing below
    # reads it, so both parametrizations exercise the same path -- confirm
    # whether the input was meant to go through the network.
    network = Network() if is_varnode else None

    data = Tensor([[1, 2, 3], [4, 5, 6]], np.int32)
    running_total = F.cumsum(data, -1)

    expected = np.array([[1, 3, 6], [4, 9, 15]]).astype(np.int32)
    np.testing.assert_equal(running_total.numpy(), expected)
@pytest.mark.parametrize("is_varnode", [True, False]) @pytest.mark.parametrize("is_varnode", [True, False])
...@@ -83,6 +97,14 @@ def test_concat(is_varnode): ...@@ -83,6 +97,14 @@ def test_concat(is_varnode):
cases = [{"input": [data1, data2]}, {"input": [data1, data3]}] cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network) opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network)
x1 = Tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
x2 = Tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
y = F.concat([x1, x2], axis=-1)
np.testing.assert_equal(
y.numpy(),
np.array([[0, 1, 2, 6, 7, 8], [3, 4, 5, 9, 10, 11]]).astype(np.float32),
)
@pytest.mark.parametrize("is_varnode", [True, False]) @pytest.mark.parametrize("is_varnode", [True, False])
def test_condtake(is_varnode): def test_condtake(is_varnode):
...@@ -139,6 +161,20 @@ def test_stack(is_varnode): ...@@ -139,6 +161,20 @@ def test_stack(is_varnode):
cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network
) )
x1 = Tensor(np.arange(0, 3, dtype=np.float32).reshape((3)))
x2 = Tensor(np.arange(6, 9, dtype=np.float32).reshape((3)))
y = F.stack([x1, x2], axis=-1)
np.testing.assert_equal(
y.numpy(), np.array([[0, 6], [1, 7], [2, 8]]).astype(np.float32)
)
x1 = Tensor(np.arange(0, 3, dtype=np.float32).reshape((3)))
x2 = Tensor(np.arange(6, 9, dtype=np.float32).reshape((3)))
y = F.stack([x1, x2], axis=-1)
np.testing.assert_equal(
y.numpy(), np.array([[0, 6], [1, 7], [2, 8]]).astype(np.float32)
)
@pytest.mark.parametrize("is_varnode", [True, False]) @pytest.mark.parametrize("is_varnode", [True, False])
def test_split_basic(is_varnode): def test_split_basic(is_varnode):
...@@ -183,6 +219,12 @@ def test_split_basic(is_varnode): ...@@ -183,6 +219,12 @@ def test_split_basic(is_varnode):
@pytest.mark.parametrize("symbolic", [None, False, True]) @pytest.mark.parametrize("symbolic", [None, False, True])
def test_split(symbolic): def test_split(symbolic):
x = Tensor(np.random.random((10, 20)), dtype=np.float32)
y = F.split(x, 3, axis=-1)
z = F.split(x, [6, 17], axis=-1)
assert str([i.numpy().shape for i in y]) == "[(10, 7), (10, 7), (10, 6)]"
assert str([i.numpy().shape for i in z]) == "[(10, 6), (10, 11), (10, 3)]"
inp1 = np.random.random((3, 4, 5, 6)).astype(np.float32) inp1 = np.random.random((3, 4, 5, 6)).astype(np.float32)
inp2 = np.random.random((0, 4, 5, 6)).astype(np.float32) inp2 = np.random.random((0, 4, 5, 6)).astype(np.float32)
...@@ -208,12 +250,43 @@ def test_split(symbolic): ...@@ -208,12 +250,43 @@ def test_split(symbolic):
fn = trace(symbolic=symbolic)(func) fn = trace(symbolic=symbolic)(func)
for i in range(3 if symbolic is not None else 1): for i in range(3 if symbolic is not None else 1):
ref_out = ref(*case) ref_out = ref(*case)
out = fn(tensor(case[0]), case[1], case[2]) out = fn(Tensor(case[0]), case[1], case[2])
assert len(ref_out) == len(out) assert len(ref_out) == len(out)
for idx in range(len(ref_out)): for idx in range(len(ref_out)):
np.testing.assert_equal(ref_out[idx], out[idx].numpy()) np.testing.assert_equal(ref_out[idx], out[idx].numpy())
def test_gather():
    """``F.gather`` along axis 1: each row picks its own entries by index."""
    source = Tensor([[1, 2], [3, 4], [5, 6]])
    picks = Tensor([[0, 1], [1, 0], [1, 1]])

    gathered = F.gather(source, 1, picks)

    expected = np.array([[1, 2], [4, 3], [6, 6]]).astype(np.int32)
    np.testing.assert_equal(gathered.numpy(), expected)
def test_scatter():
    """``F.scatter`` with a negative axis (``-2``) into a ``(3, 5)`` zero tensor.

    The output is rounded to four decimals before the comparison.
    """
    target = Tensor(np.zeros(shape=(3, 5), dtype=np.float32))
    values = Tensor(
        [
            [0.9935, 0.9465, 0.2256, 0.8926, 0.4396],
            [0.7723, 0.0718, 0.5939, 0.357, 0.4576],
        ]
    )
    rows = Tensor([[0, 2, 0, 2, 1], [2, 0, 1, 1, 2]])

    scattered = F.scatter(target, -2, rows, values)

    expected = np.array(
        [
            [0.9935, 0.0718, 0.2256, 0.0, 0.0],
            [0.0, 0.0, 0.5939, 0.357, 0.4396],
            [0.7723, 0.9465, 0.0, 0.8926, 0.4576],
        ]
    ).astype(np.float32)
    np.testing.assert_equal(scattered.numpy().round(decimals=4), expected)
@pytest.mark.parametrize("is_varnode", [True, False]) @pytest.mark.parametrize("is_varnode", [True, False])
def test_swapaxes(is_varnode): def test_swapaxes(is_varnode):
if is_varnode: if is_varnode:
...@@ -221,7 +294,7 @@ def test_swapaxes(is_varnode): ...@@ -221,7 +294,7 @@ def test_swapaxes(is_varnode):
else: else:
network = None network = None
x = tensor(np.array([[1, 2, 3]], dtype=np.int32)) x = Tensor(np.array([[1, 2, 3]], dtype=np.int32))
y = F.swapaxes(x, 0, 1) y = F.swapaxes(x, 0, 1)
np.testing.assert_equal(y.numpy(), np.array([[1], [2], [3]]).astype(np.int32)) np.testing.assert_equal(y.numpy(), np.array([[1], [2], [3]]).astype(np.int32))
...@@ -280,15 +353,15 @@ def test_broadcast_auto_infer(is_varnode): ...@@ -280,15 +353,15 @@ def test_broadcast_auto_infer(is_varnode):
def test_reshape_on_empty_tensor(is_trace): def test_reshape_on_empty_tensor(is_trace):
input1_shape = (100, 0, 1) input1_shape = (100, 0, 1)
output1_shape = (100, 0, 10) output1_shape = (100, 0, 10)
data1 = tensor(np.random.random(input1_shape).astype(np.float32)) data1 = Tensor(np.random.random(input1_shape).astype(np.float32))
input2_shape = (10, 0) input2_shape = (10, 0)
output2_shape = (0,) output2_shape = (0,)
data2 = tensor(np.random.random(input2_shape).astype(np.float32)) data2 = Tensor(np.random.random(input2_shape).astype(np.float32))
input3_shape = (10, 0, 10) input3_shape = (10, 0, 10)
output3_shape = (0, 1, 2, 3) output3_shape = (0, 1, 2, 3)
data3 = tensor(np.random.random(input3_shape).astype(np.float32)) data3 = Tensor(np.random.random(input3_shape).astype(np.float32))
def comp(out, target_shp): def comp(out, target_shp):
assert out._tuple_shape == target_shp assert out._tuple_shape == target_shp
...@@ -338,7 +411,7 @@ def test_reshape_shape_inference(is_varnode): ...@@ -338,7 +411,7 @@ def test_reshape_shape_inference(is_varnode):
def check_shape(output, target): def check_shape(output, target):
source = output.shape source = output.shape
if isinstance(source, tensor): if isinstance(source, Tensor):
source = source.numpy() source = source.numpy()
np.testing.assert_equal(source, target.shape) np.testing.assert_equal(source, target.shape)
...@@ -366,6 +439,10 @@ def test_squeeze(is_varnode): ...@@ -366,6 +439,10 @@ def test_squeeze(is_varnode):
else: else:
network = None network = None
x = Tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
y = F.squeeze(x, -1)
np.testing.assert_equal(y.numpy(), np.array([[[1, 2]]]).astype(np.int32))
x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1) x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
xx = make_tensor(x, network) xx = make_tensor(x, network)
...@@ -385,6 +462,12 @@ def test_expand_dims(is_varnode): ...@@ -385,6 +462,12 @@ def test_expand_dims(is_varnode):
else: else:
network = None network = None
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.expand_dims(x, -1)
np.testing.assert_equal(
y.numpy(), np.array([[[1], [2], [3]], [[4], [5], [6]]]).astype(np.int32)
)
x = np.arange(6, dtype="float32").reshape(2, 3) x = np.arange(6, dtype="float32").reshape(2, 3)
xx = make_tensor(x, network) xx = make_tensor(x, network)
...@@ -533,6 +616,22 @@ def test_flatten(is_varnode): ...@@ -533,6 +616,22 @@ def test_flatten(is_varnode):
else: else:
network = None network = None
inp_shape = (2, 2, 3, 3)
x = Tensor(np.arange(36, dtype=np.int32).reshape(inp_shape),)
y = F.flatten(x, -2, -1)
np.testing.assert_equal(
y.numpy(),
np.array(
[
[[0, 1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16, 17]],
[
[18, 19, 20, 21, 22, 23, 24, 25, 26],
[27, 28, 29, 30, 31, 32, 33, 34, 35],
],
]
).astype(np.int32),
)
data0_shape = (2, 3, 4, 5) data0_shape = (2, 3, 4, 5)
data1_shape = (4, 5, 6, 7) data1_shape = (4, 5, 6, 7)
data0 = np.random.random(data0_shape).astype(np.float32) data0 = np.random.random(data0_shape).astype(np.float32)
...@@ -616,15 +715,15 @@ def test_broadcast(is_varnode): ...@@ -616,15 +715,15 @@ def test_broadcast(is_varnode):
def test_broadcast_on_empty_tensor(is_trace): def test_broadcast_on_empty_tensor(is_trace):
input1_shape = (100, 0, 1) input1_shape = (100, 0, 1)
output1_shape = (100, 0, 10) output1_shape = (100, 0, 10)
data1 = tensor(np.random.random(input1_shape).astype(np.float32)) data1 = Tensor(np.random.random(input1_shape).astype(np.float32))
input2_shape = (10, 0) input2_shape = (10, 0)
output2_shape = (10, 10, 0) output2_shape = (10, 10, 0)
data2 = tensor(np.random.random(input2_shape).astype(np.float32)) data2 = Tensor(np.random.random(input2_shape).astype(np.float32))
input3_shape = (0, 0, 1, 10) input3_shape = (0, 0, 1, 10)
output3_shape = (10, 0, 0, 10, 10) output3_shape = (10, 0, 0, 10, 10)
data3 = tensor(np.random.random(input3_shape).astype(np.float32)) data3 = Tensor(np.random.random(input3_shape).astype(np.float32))
def comp(out, target_shp): def comp(out, target_shp):
assert out._tuple_shape == target_shp assert out._tuple_shape == target_shp
...@@ -705,7 +804,7 @@ def test_utils_astensor1d(is_varnode): ...@@ -705,7 +804,7 @@ def test_utils_astensor1d(is_varnode):
def test_device(): def test_device():
x = tensor([1, 2, 3], dtype="float32") x = Tensor([1, 2, 3], dtype="float32")
y1 = F.eye(x.shape, dtype="float32") y1 = F.eye(x.shape, dtype="float32")
y2 = F.eye(x.shape, dtype="float32", device=None) y2 = F.eye(x.shape, dtype="float32", device=None)
...@@ -789,7 +888,7 @@ def test_copy_d2d(is_varnode): ...@@ -789,7 +888,7 @@ def test_copy_d2d(is_varnode):
) )
@pytest.mark.parametrize("is_symbolic", [None, True, False]) @pytest.mark.parametrize("is_symbolic", [None, True, False])
def test_copy_empty(shape, device_src, device_dst, is_symbolic): def test_copy_empty(shape, device_src, device_dst, is_symbolic):
inp = tensor(np.random.randn(*shape).astype("float32"), device=device_src) inp = Tensor(np.random.randn(*shape).astype("float32"), device=device_src)
def func(inp): def func(inp):
return F.copy(inp, device_dst) return F.copy(inp, device_dst)
...@@ -885,6 +984,12 @@ def test_roll(shape, shifts, axis, is_varnode): ...@@ -885,6 +984,12 @@ def test_roll(shape, shifts, axis, is_varnode):
else: else:
network = None network = None
x = Tensor([[1, 2], [3, 4], [5, 6]], np.int32)
y = F.roll(x, 1, -1)
np.testing.assert_equal(
y.numpy(), np.array([[2, 1], [4, 3], [6, 5]]).astype(np.int32)
)
inp = np.random.randn(*shape).astype("float32") inp = np.random.randn(*shape).astype("float32")
def func(inp): def func(inp):
...@@ -904,7 +1009,7 @@ def test_roll(shape, shifts, axis, is_varnode): ...@@ -904,7 +1009,7 @@ def test_roll(shape, shifts, axis, is_varnode):
) )
@pytest.mark.parametrize("is_symbolic", [None, True, False]) @pytest.mark.parametrize("is_symbolic", [None, True, False])
def test_roll_empty_tensor(shape, shifts, axis, is_symbolic): def test_roll_empty_tensor(shape, shifts, axis, is_symbolic):
inp = tensor(np.random.randn(*shape).astype("float32")) inp = Tensor(np.random.randn(*shape).astype("float32"))
def func(inp): def func(inp):
return F.roll(inp, shifts, axis) return F.roll(inp, shifts, axis)
......
#include "../dnn_op_helper.h"
#include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/autogen.h"
#include "../op_trait.h" #include "../op_trait.h"
#include "megbrain/opr/indexing.h" #include "megbrain/opr/indexing.h"
#include "megdnn/oprs/general.h"
namespace mgb { namespace mgb {
namespace imperative { namespace imperative {
...@@ -12,10 +14,8 @@ namespace indexing_one_hot { ...@@ -12,10 +14,8 @@ namespace indexing_one_hot {
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) { const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
auto& op = def.cast_final_safe<IndexingOneHot>(); auto&& op = def.cast_final_safe<IndexingOneHot>();
mgb_assert(input_descs.size() == 2, "IndexingOneHot expects two inputs"); mgb_assert(input_descs.size() == 2, "IndexingOneHot expects two inputs");
auto comp_node = input_descs[0].comp_node; auto comp_node = input_descs[0].comp_node;
TensorLayout src = input_descs[0].layout, index = input_descs[1].layout; TensorLayout src = input_descs[0].layout, index = input_descs[1].layout;
...@@ -28,10 +28,15 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( ...@@ -28,10 +28,15 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
mgb_assert(src.ndim >= 2, "src ndim must be at least 2"); mgb_assert(src.ndim >= 2, "src ndim must be at least 2");
mgb_assert(src.is_contiguous(), "src should be contiguous"); mgb_assert(src.is_contiguous(), "src should be contiguous");
mgb_assert( mgb_assert(
op.axis >= 0 && op.axis < src.ndim, "axis %d not exists in src", op.axis); -static_cast<int>(src.ndim) <= op.axis &&
op.axis < static_cast<int>(src.ndim),
"axis %d not exists in src", op.axis);
int real_axis = static_cast<int>(op.axis);
if (real_axis < 0) {
real_axis += static_cast<int>(src.ndim);
}
TensorLayout dst = src; TensorLayout dst = src;
dst.shape[op.axis] = 1; dst.shape[real_axis] = 1;
dst.init_contiguous_stride(); dst.init_contiguous_stride();
if (!index.ndim) { if (!index.ndim) {
...@@ -40,24 +45,128 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( ...@@ -40,24 +45,128 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
mgb_assert(index.is_contiguous(), "index should be all contiguous"); mgb_assert(index.is_contiguous(), "index should be all contiguous");
mgb_assert( mgb_assert(
index.eq_shape(src.remove_axis(op.axis)), "index shape doesn't match src"); index.eq_shape(src.remove_axis(real_axis)),
"index shape doesn't match src");
return {{{dst, comp_node}}, true}; return {{{dst, comp_node}}, true};
} }
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const IndexingOneHot&>(def); auto&& op = def.cast_final_safe<IndexingOneHot>();
mgb_assert(inputs.size() == 2); mgb_assert(inputs.size() == 2);
int real_axis = static_cast<int>(op.axis);
if (real_axis < 0) {
real_axis += static_cast<int>(op.ndim);
}
OperatorNodeConfig config{op.make_name()}; OperatorNodeConfig config{op.make_name()};
return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config); return opr::IndexingOneHot::make(inputs[0], inputs[1], real_axis, config);
}
/*!
 * \brief Run IndexingOneHot directly on concrete tensors.
 *
 * The op's axis may be negative; it is then counted from the end of the source
 * layout. The axis is validated against the source rank before it is handed to
 * the megdnn operator.
 *
 * \param def op definition; must be an IndexingOneHot
 * \param inputs {src, index}
 * \param output_descs not read by this implementation
 * \param validated not read by this implementation
 * \return one newly created tensor whose layout is deduced by the megdnn operator
 */
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, SmallVector<TensorPtr> inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op = def.cast_final_safe<IndexingOneHot>();
    // Guard the two element accesses below.
    mgb_assert(inputs.size() == 2, "IndexingOneHot expects two inputs");
    auto&& inp = inputs[0];
    auto&& index = inputs[1];
    TensorLayout layout = inp->layout();
    TensorLayout index_layout = index->layout();
    DnnOprCaller<megdnn::IndexingOneHot> dnn_op(inp->comp_node());
    auto&& indexing_one_hot_param = dnn_op.op->param();

    // Accept axis in [-ndim, ndim - 1]; report that same range on failure so the
    // message matches what is actually allowed.
    const int ndim = static_cast<int>(layout.ndim);
    const int axis = static_cast<int>(op.axis);
    mgb_assert(
            -ndim <= axis && axis < ndim,
            "Dimension out of range (expected to be in range of [%d, %d], but got "
            "%d)",
            -ndim, ndim - 1, axis);
    const int real_axis = axis < 0 ? axis + ndim : axis;
    indexing_one_hot_param = real_axis;

    // Let megdnn derive the output layout from src/index and the axis set above.
    TensorLayout tlayout;
    dnn_op.op->deduce_layout(layout, index_layout, tlayout);
    TensorPtr out = Tensor::make(tlayout, inp->comp_node());

    megdnn::TensorND in = inp->dnn_tensor();
    megdnn::TensorND ind = index->dnn_tensor();

    // Workspace is described as a 1-d byte layout of the size megdnn asks for.
    TensorLayout m_layout(
            {dnn_op.op->get_workspace_in_bytes(layout, index_layout, tlayout)},
            dtype::Byte());
    auto dnn_workspace = dnn_op.create_workspace(m_layout);
    dnn_op.op->exec(in, ind, out->dnn_tensor(), dnn_workspace);
    return {out};
}
OP_TRAIT_REG(IndexingOneHot, IndexingOneHot) OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
.infer_output_attrs_fallible(infer_output_attrs_fallible) .infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_var_node(apply_on_var_node) .apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback(); .fallback();
} // namespace indexing_one_hot } // namespace indexing_one_hot
namespace indexing_set_one_hot {
/*!
 * \brief Infer the output descriptor of IndexingSetOneHot from its input descs.
 *
 * Expects three inputs, of which the second (index) must be int32. When the
 * rank of the first input is not known yet, an empty-shape descriptor with the
 * source dtype is returned and the result is flagged as not validated.
 * Otherwise the first input must be contiguous and its descriptor is returned
 * as the output descriptor.
 */
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
    mgb_assert(input_descs.size() == 3, "IndexingSetOneHot expects three inputs");

    const auto& src_desc = input_descs[0];
    const TensorLayout& src = src_desc.layout;
    const TensorLayout& index = input_descs[1].layout;
    mgb_assert(index.dtype == dtype::Int32(), "index dtype must be int32");

    if (src.ndim == 0) {
        // Shape of src is still unknown: only dtype and comp node can be reported.
        return {{{{{}, src.dtype}, src_desc.comp_node}}, false};
    }

    mgb_assert(src.is_contiguous(), "src should be contiguous");
    return {{src_desc}, true};
}
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const IndexingSetOneHot&>(def);
mgb_assert(inputs.size() == 3);
int real_axis = static_cast<int>(op.axis);
if (real_axis < 0) {
real_axis += static_cast<int>(op.ndim);
}
OperatorNodeConfig config{op.make_name()};
return opr::IndexingSetOneHot::make(
inputs[0], inputs[1], inputs[2], real_axis, config);
}
/*!
 * \brief Run IndexingSetOneHot directly on concrete tensors.
 *
 * The output is created with the source layout and filled with a copy of the
 * source; the megdnn operator is then executed on that copy together with the
 * index and sub tensors, so the source tensor itself is not passed to exec.
 *
 * The op's axis may be negative; it is then counted from the end of the source
 * layout. The axis is validated against the source rank before it is handed to
 * the megdnn operator.
 *
 * \param def op definition; must be an IndexingSetOneHot
 * \param inputs {src, index, sub}
 * \param output_descs not read by this implementation
 * \param validated not read by this implementation
 * \return one newly created tensor with the layout of src
 */
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, SmallVector<TensorPtr> inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op = def.cast_final_safe<IndexingSetOneHot>();
    // Guard the three element accesses below.
    mgb_assert(inputs.size() == 3, "IndexingSetOneHot expects three inputs");
    auto&& inp = inputs[0];
    auto&& index = inputs[1];
    auto&& sub = inputs[2];
    TensorLayout layout = inp->layout();
    TensorLayout index_layout = index->layout();
    TensorLayout tlayout = sub->layout();
    mgb_assert(layout.is_contiguous());
    DnnOprCaller<megdnn::IndexingSetOneHot> dnn_op(inp->comp_node());
    auto&& indexing_one_hot_param = dnn_op.op->param();

    // Accept axis in [-ndim, ndim - 1] and reject everything else here instead
    // of forwarding an invalid axis to the kernel.
    const int ndim = static_cast<int>(layout.ndim);
    const int axis = static_cast<int>(op.axis);
    mgb_assert(
            -ndim <= axis && axis < ndim,
            "Dimension out of range (expected to be in range of [%d, %d], but got "
            "%d)",
            -ndim, ndim - 1, axis);
    const int real_axis = axis < 0 ? axis + ndim : axis;
    indexing_one_hot_param = real_axis;

    // Work on a copy of src so the input tensor is left untouched.
    TensorPtr out = Tensor::make(layout, inp->comp_node());
    out->dev_tensor().copy_from_fixlayout(inp->dev_tensor());

    megdnn::TensorND ind = index->dnn_tensor();
    megdnn::TensorND su = sub->dnn_tensor();

    // Workspace is described as a 1-d byte layout of the size megdnn asks for.
    TensorLayout m_layout(
            {dnn_op.op->get_workspace_in_bytes(layout, index_layout, tlayout)},
            dtype::Byte());
    auto dnn_workspace = dnn_op.create_workspace(m_layout);
    dnn_op.op->exec(out->dnn_tensor(), ind, su, dnn_workspace);
    return {out};
}
// Register the IndexingSetOneHot op trait with the infer_output_attrs_fallible,
// apply_on_var_node and apply_on_physical_tensor functions of this namespace.
// NOTE(review): .fallback() presumably fills the hooks not set here with
// defaults -- confirm against op_trait.h.
OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace indexing_set_one_hot
} // anonymous namespace } // anonymous namespace
} // namespace imperative } // namespace imperative
} // namespace mgb } // namespace mgb
......
...@@ -372,21 +372,6 @@ OP_TRAIT_REG(GroupLocal, GroupLocal).apply_on_var_node(apply_on_var_node).fallba ...@@ -372,21 +372,6 @@ OP_TRAIT_REG(GroupLocal, GroupLocal).apply_on_var_node(apply_on_var_node).fallba
} // namespace group_local } // namespace group_local
} // namespace } // namespace
namespace {
namespace indexing_set_one_hot {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const IndexingSetOneHot&>(def);
mgb_assert(inputs.size() == 3);
OperatorNodeConfig config{op.make_name()};
return opr::IndexingSetOneHot::make(
inputs[0], inputs[1], inputs[2], op.param(), config);
}
OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
.apply_on_var_node(apply_on_var_node)
.fallback();
} // namespace indexing_set_one_hot
} // namespace
namespace { namespace {
namespace typecvt { namespace typecvt {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
......
...@@ -108,9 +108,17 @@ def Remap: MgbHashableOp<"Remap", [RemapParam]>; ...@@ -108,9 +108,17 @@ def Remap: MgbHashableOp<"Remap", [RemapParam]>;
def Resize: MgbHashableOp<"Resize", [ResizeParam]>; def Resize: MgbHashableOp<"Resize", [ResizeParam]>;
def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]>; def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]> {
let extraArguments = (ins
MgbI32Attr:$ndim
);
}
def IndexingSetOneHot: MgbHashableOp<"IndexingSetOneHot", [AxisParam]>; def IndexingSetOneHot: MgbHashableOp<"IndexingSetOneHot", [AxisParam]> {
let extraArguments = (ins
MgbI32Attr:$ndim
);
}
def Copy: MgbHashableOp<"Copy"> { def Copy: MgbHashableOp<"Copy"> {
let extraArguments = (ins let extraArguments = (ins
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册