Commit 24c5c19b authored by Megvii Engine Team

fix(imperative): make functional ops support negative axis

GitOrigin-RevId: f61e01270b948ab5bd6ba32b091d7c6b8d7a0745
Parent c76e80bc
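
For context, a minimal usage sketch of the headline change: `indexing_one_hot` accepting a negative axis. This is illustrative only (not part of the diff) and mirrors the `test_indexing` case added below.

```python
import numpy as np
import megengine as mge
import megengine.functional as F

src = mge.Tensor(np.array([[1.0, 2.0]], dtype=np.float32))
index = mge.Tensor([0])

# axis=-1 resolves to the last dimension of src (here: axis 1).
out = F.indexing_one_hot(src, index, -1)
print(out.numpy())  # [1.]
```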
......@@ -1015,7 +1015,7 @@ class IndexingOneHotBase : public OperatorBase {
DEF_OPR_PARAM(Axis);
protected:
void deduce_layout_fwd(
MGE_WIN_DECLSPEC_FUC void deduce_layout_fwd(
const TensorLayout& src, const TensorLayout& index, TensorLayout& dst);
void check_layout_fwd(
const TensorLayout& src, const TensorLayout& index,
......
......@@ -1558,7 +1558,7 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor:
)
ones_tensor = ones(list(inp.shape) + [1], dtype=inp.dtype, device=inp.device)
op = builtin.IndexingSetOneHot(axis=inp.ndim)
op = builtin.IndexingSetOneHot(axis=inp.ndim, ndim=inp.ndim)
(result,) = apply(op, zeros_tensor, inp, ones_tensor)
return result
......@@ -1609,7 +1609,7 @@ def indexing_one_hot(
array([1.], dtype=float32)
"""
assert isinstance(src, Tensor), "src must be of Tensor type"
op = builtin.IndexingOneHot(axis=axis)
op = builtin.IndexingOneHot(axis=axis, ndim=src.ndim)
index = convert_single_value(index, dtype="int32", device=src.device)
(result,) = apply(op, src, index)
if not keepdims:
......
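
Passing `ndim` alongside the (possibly negative) `axis` lets the C++ implementation canonicalize the axis itself. A sketch of that normalization, with an illustrative helper name (the real logic lives in the `real_axis` computation in the C++ changes below):

```python
def _canonicalize_axis(axis: int, ndim: int) -> int:
    # A negative axis counts from the end, mirroring the C++ `real_axis` logic.
    assert -ndim <= axis < ndim, "axis {} out of range for ndim {}".format(axis, ndim)
    return axis + ndim if axis < 0 else axis

assert _canonicalize_axis(-1, 2) == 1
assert _canonicalize_axis(0, 2) == 0
```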
......@@ -393,6 +393,8 @@ def split(inp, nsplits_or_sections, axis=0):
def _get_idx(index, axis):
index_dims = len(index.shape)
idx = []
if axis < 0:
axis += index_dims
for i in range(index_dims):
if i != axis:
shape = [1] * index_dims
......@@ -457,21 +459,6 @@ def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
"But the input dims:{}, the index dims:{}".format(input_dims, index_dims)
)
if axis < 0 or axis >= input_dims:
raise ValueError(
"Index axis {} is output of bounds, should in range [0 {})".format(
axis, input_dims
)
)
for i in range(input_dims):
if i != axis and input_shape[i] != index_shape[i]:
raise ValueError(
"The input {} and index {} must have the same size apart from axis {}".format(
input_shape, index_shape, axis
)
)
idx = _get_idx(index, axis)
return inp[idx].reshape(index.shape) # pylint: disable=no-member
......@@ -524,7 +511,7 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
>>> inp = Tensor(np.zeros(shape=(3,5),dtype=np.float32))
>>> source = Tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]])
>>> index = Tensor([[0,2,0,2,1],[2,0,1,1,2]])
>>> oup = F.scatter(inp, 0, index,source)
>>> oup = F.scatter(inp, 0, index, source)
>>> oup.numpy()
array([[0.9935, 0.0718, 0.2256, 0. , 0. ],
[0. , 0. , 0.5939, 0.357 , 0.4396],
......@@ -540,13 +527,6 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
if input_dims != index_dims or input_dims != source_dims:
raise ValueError("The input, source and index tensor must have same dimensions")
if axis < 0 or axis >= input_dims:
raise ValueError(
"Index axis {} is output of bounds, should in range [0 {})".format(
axis, input_dims
)
)
for i in range(source_dims):
if source_shape[i] > input_shape[i]:
raise ValueError(
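
With the Python-level `[0, input_dims)` checks above removed, a negative axis reaches `_get_idx`, which now wraps it. A sketch matching the `test_scatter` case added later in this commit:

```python
import numpy as np
import megengine.functional as F
from megengine import Tensor

inp = Tensor(np.zeros((3, 5), dtype=np.float32))
index = Tensor([[0, 2, 0, 2, 1], [2, 0, 1, 1, 2]])
source = Tensor([[0.9935, 0.9465, 0.2256, 0.8926, 0.4396],
                 [0.7723, 0.0718, 0.5939, 0.357, 0.4576]])

# axis=-2 is equivalent to axis=0 for this 2-D input.
out = F.scatter(inp, -2, index, source)
```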
......@@ -792,6 +772,8 @@ def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor:
>>> out.numpy().shape
(2, 2, 9)
"""
if start_axis < 0:
start_axis += len(inp.shape)
target_shape = tuple(inp.shape[i] for i in range(start_axis)) + (-1,)
if end_axis != -1:
target_shape += (*inp.shape[end_axis + 1 :],)
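
A sketch of `flatten` with negative axes after the `start_axis` normalization above, mirroring the docstring example and the new `test_flatten` case:

```python
import numpy as np
import megengine.functional as F
from megengine import Tensor

x = Tensor(np.arange(36, dtype=np.int32).reshape(2, 2, 3, 3))
# start_axis=-2 maps to axis 2, end_axis=-1 to the last axis: merge them.
out = F.flatten(x, -2, -1)
assert out.numpy().shape == (2, 2, 9)
```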
......@@ -1158,6 +1140,5 @@ def cumsum(inp: Tensor, axis: int):
[ 4 9 15]], dtype=int32, device=xpux:0)
"""
assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor"
assert axis >= 0 and axis < inp.ndim, "input axis {} out of bound".format(axis)
op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False)
return apply(op, inp)[0]
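
With the `axis >= 0` assertion dropped, `cumsum` accepts a negative axis directly; a sketch mirroring the new `test_cumsum` case:

```python
import numpy as np
import megengine.functional as F
from megengine import Tensor

x = Tensor([[1, 2, 3], [4, 5, 6]], np.int32)
y = F.cumsum(x, -1)  # cumulative sum along the last axis
# expected: [[1, 3, 6], [4, 9, 15]]
```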
......@@ -490,6 +490,84 @@ std::optional<ValueRefList> pixelShuffle_grad_rule(
return imperative::apply(op, inputs);
}
std::optional<ValueRefList> indexing_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& indexing = op.cast_final_safe<IndexingOneHot>();
mgb_assert(inputs.size() == 2);
bool flag = inputs_require_grad[0];
auto&& grad_op = IndexingSetOneHot::make(indexing.axis, indexing.ndim);
SmallVector<ValueRef> inputs2;
if (flag) {
inputs2.push_back(get_shape(inputs[0]));
for (size_t i = 1; i < inputs.size(); ++i) {
inputs2.push_back(inputs[i]);
}
}
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([inputs = std::move(inputs2),
grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
SmallVector<ValueRef> ret(1);
if (grad && inputs[0]) {
ValueRefList args_(inputs.size() + 1);
auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
args_[0] = zeros;
args_[1] = inputs[1];
args_[2] = grads[0];
ret[0] = imperative::apply(*grad_op_, args_)[0];
}
return ret;
});
maker.finalize();
return imperative::apply(op, inputs);
}
std::optional<ValueRefList> indexing_set_one_hot_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& indexingSetOneHot = op.cast_final_safe<IndexingSetOneHot>();
mgb_assert(inputs.size() == 3);
SmallVector<ValueRef> inputs2;
inputs2.push_back(get_shape(inputs[0]));
inputs2.push_back(inputs[1]);
inputs2.push_back(get_shape(inputs[2]));
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([inputs = std::move(inputs2),
&indexingSetOneHot](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
SmallVector<ValueRef> ret(3);
if (!grad) {
return ret;
}
if (inputs[0]) {
auto&& grad_op = IndexingSetOneHot::make(
indexingSetOneHot.axis, indexingSetOneHot.ndim);
ValueRefList args_(inputs.size());
auto&& zeros = make_empty_tensor(grad.device(), inputs[2], grad.dtype());
args_[0] = grads[0];
args_[1] = inputs[1];
args_[2] = zeros;
ret[0] = imperative::apply(*grad_op, args_)[0];
}
if (inputs[2]) {
auto&& grad_op = IndexingOneHot::make(
indexingSetOneHot.axis, indexingSetOneHot.ndim);
ValueRefList args_(inputs.size() - 1);
args_[0] = grads[0];
args_[1] = inputs[1];
ret[2] = imperative::apply(*grad_op, args_)[0];
}
return ret;
});
maker.finalize();
return imperative::apply(op, inputs);
}
std::optional<ValueRefList> fastpathcopy_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
......@@ -521,6 +599,10 @@ struct Init {
CustomBackward::register_grad_rule(AddAxis::typeinfo(), addAxis_grad_rule);
CustomBackward::register_grad_rule(
RemoveAxis::typeinfo(), removeAxis_grad_rule);
CustomBackward::register_grad_rule(
IndexingOneHot::typeinfo(), indexing_grad_rule);
CustomBackward::register_grad_rule(
IndexingSetOneHot::typeinfo(), indexing_set_one_hot_grad_rule);
CustomBackward::register_grad_rule(
FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
CustomBackward::register_grad_rule(
......
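
The new grad rules above encode a symmetry: the gradient of `IndexingOneHot` with respect to its source is an `IndexingSetOneHot` of the incoming gradient into zeros of the source shape, and vice versa. A NumPy reference sketch of that relationship (assuming the axis is already canonicalized), consistent with the expected gradients in the tests below:

```python
import numpy as np

def indexing_one_hot_ref(src, index, axis):
    # forward: pick src values at `index` along `axis`, keeping that axis as size 1
    return np.take_along_axis(src, np.expand_dims(index, axis), axis)

def indexing_one_hot_grad_ref(src_shape, index, grad, axis):
    # backward w.r.t. src: scatter the incoming gradient into zeros of src's shape
    out = np.zeros(src_shape, dtype=grad.dtype)
    np.put_along_axis(out, np.expand_dims(index, axis), grad, axis)
    return out

src = np.array([[1.0, 2.0]], dtype=np.float32)
index = np.array([0])
print(indexing_one_hot_ref(src, index, 1))                        # [[1.]]
print(indexing_one_hot_grad_ref(src.shape, index,
                                np.ones((1, 1), np.float32), 1))  # [[1. 0.]]
```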
......@@ -8,11 +8,15 @@ import megengine as mge
import megengine.distributed as dist
import megengine.functional as F
import megengine.module as M
from megengine import Tensor
from megengine.core import _imperative_rt
from megengine.core._imperative_rt import CompNode, TensorAttr, imperative
from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync
from megengine.core.autodiff.grad import Grad
from megengine.core.ops import builtin
from megengine.core.ops.builtin import Elemwise, Identity
from megengine.functional.distributed import remote_recv, remote_send
from megengine.functional.tensor import ones, zeros
def _elwise(mode):
......@@ -553,3 +557,46 @@ def test_matmul():
if ydim == 1 and transposeB == True:
continue
test_one(xdim, ydim, transposeA, transposeB)
def test_indexing():
x = np.array([[1.0, 2.0]]).astype("float32")
x = mge.Tensor(x)
index = mge.Tensor([0])
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
def f(x):
return F.indexing_one_hot(x, index, -1)
y = f(x)
grad(y, F.ones_like(y))
np.testing.assert_equal(np.array([[1, 0]], dtype=np.float32), x.grad.numpy())
def test_indexing_set_one_hot():
x = mge.tensor(np.arange(1, 4, dtype=np.int32))
with Grad() as grad:
zeros_tensor = zeros((3, 4), dtype=x.dtype, device=x.device)
ones_tensor = ones((3, 1), dtype=x.dtype, device=x.device)
grad.wrt(zeros_tensor, callback=save_to(zeros_tensor))
grad.wrt(ones_tensor, callback=save_to(ones_tensor))
def f(x):
op = builtin.IndexingSetOneHot(axis=x.ndim, ndim=x.ndim)
(result,) = apply(op, zeros_tensor, x, ones_tensor)
return result
y = f(x)
grad(y, F.ones_like(y))
np.testing.assert_equal(
np.array([[1, 0, 1, 1], [1, 1, 0, 1], [1, 1, 1, 0]], dtype=np.int32),
zeros_tensor.grad.numpy(),
)
np.testing.assert_equal(
np.array([[1], [1], [1]], dtype=np.int32), ones_tensor.grad.numpy(),
)
......@@ -6,9 +6,7 @@ import pytest
import megengine.autodiff as ad
import megengine.functional as F
import megengine.optimizer as optimizer
from megengine import Parameter
from megengine import Tensor as tensor
from megengine import tensor
from megengine import Parameter, Tensor, tensor
from megengine.autodiff import Function
from megengine.module import Module
......
......@@ -3,15 +3,15 @@ import numpy as np
import pytest
import megengine.functional as F
from megengine import tensor
import megengine.tensor as Tensor
def test_cross_entropy_with_logits():
data = tensor([[0, 50], [0, -150]]).astype(np.float32)
label = tensor([1, 0]).astype(np.int32)
data = Tensor([[0, 50], [0, -150]]).astype(np.float32)
label = Tensor([1, 0]).astype(np.int32)
loss = F.nn.cross_entropy(data, label)
np.testing.assert_allclose(loss.numpy(), 0.0)
label = tensor([0, 1]).astype(np.int32)
label = Tensor([0, 1]).astype(np.int32)
loss = F.nn.cross_entropy(data, label)
np.testing.assert_allclose(loss.numpy(), 100)
......@@ -35,19 +35,24 @@ def test_cross_entropy():
x[i, y[i]] += np.random.rand() * 2
x = softmax(x)
l_ref = ref(x, y)
l = F.nn.cross_entropy(tensor(x, "float32"), tensor(y, "int32"), with_logits=False)
l = F.nn.cross_entropy(Tensor(x, "float32"), Tensor(y, "int32"), with_logits=False)
np.testing.assert_allclose(l.numpy(), l_ref, 1e-6, 1e-6)
l1 = F.nn.cross_entropy(
Tensor(x, "float32"), Tensor(y, "int32"), axis=-1, with_logits=False
)
np.testing.assert_allclose(l1.numpy(), l_ref, 1e-6, 1e-6)
def test_cross_entropy_reduction():
logits = np.random.randn(16, 10)
label = np.random.randint(10, size=[16])
logits = tensor(logits, dtype="float32")
label = tensor(label, dtype="int32")
logits = Tensor(logits, dtype="float32")
label = Tensor(label, dtype="int32")
perm = np.random.permutation(16)
logits_perm = tensor(logits[perm], dtype="float32")
label_perm = tensor(label[perm], dtype="int32")
logits_perm = Tensor(logits[perm], dtype="float32")
label_perm = Tensor(label[perm], dtype="int32")
loss = F.nn.cross_entropy(logits, label, reduction="none")
loss_perm = F.nn.cross_entropy(logits_perm, label_perm, reduction="none")
......@@ -160,18 +165,18 @@ def _ctc_npy_single_seq(pred, label, blank):
def test_ctc_loss():
def test_func(T, C, N):
input = np.random.randn(T, N, C)
input = F.softmax(tensor(input), axis=-1).numpy()
input = F.softmax(Tensor(input), axis=-1).numpy()
input_lengths = np.ones(N, dtype=np.int32) * T
target_lengths = np.random.randint(low=1, high=T + 1, size=(N,), dtype=np.int32)
target = np.random.randint(
low=1, high=C, size=(sum(target_lengths)), dtype=np.int32
)
input_mge = tensor(input)
input_lengths_mge = tensor(input_lengths)
input_mge = Tensor(input)
input_lengths_mge = Tensor(input_lengths)
target_mge = tensor(target)
target_lengths_mge = tensor(target_lengths)
target_mge = Tensor(target)
target_lengths_mge = Tensor(target_lengths)
blank = np.random.randint(C)
for method in ["mean", "sum", "none"]:
......
......@@ -6,7 +6,7 @@ import pytest
from utils import opr_test
import megengine.functional as F
from megengine import jit, tensor
from megengine import Tensor, jit, tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin
......@@ -61,37 +61,84 @@ def common_test_reduce(opr, ref_opr):
def test_sum():
common_test_reduce(opr=F.sum, ref_opr=np.sum)
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.sum(x, axis=-1)
np.testing.assert_equal(y.numpy(), np.array([6, 15]).astype(np.int32))
def test_prod():
common_test_reduce(opr=F.prod, ref_opr=np.prod)
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.prod(x, axis=-2)
np.testing.assert_equal(y.numpy(), np.array([4, 10, 18]).astype(np.int32))
def test_mean():
common_test_reduce(opr=F.mean, ref_opr=np.mean)
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.mean(x, axis=-2)
np.testing.assert_equal(y.numpy(), np.array([2.5, 3.5, 4.5]).astype(np.float32))
def test_var():
common_test_reduce(opr=F.var, ref_opr=np.var)
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.var(x, axis=-2)
np.testing.assert_equal(y.numpy(), np.array([2.25, 2.25, 2.25]).astype(np.float32))
def test_std():
common_test_reduce(opr=F.std, ref_opr=np.std)
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.std(x, axis=-2)
np.testing.assert_equal(y.numpy(), np.array([1.5, 1.5, 1.5]).astype(np.float32))
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.std(x, axis=-2)
np.testing.assert_equal(y.numpy(), np.array([1.5, 1.5, 1.5]).astype(np.float32))
def test_min():
common_test_reduce(opr=F.min, ref_opr=np.min)
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.min(x, axis=-1)
np.testing.assert_equal(y.numpy(), np.array([1, 4]).astype(np.int32))
def test_max():
common_test_reduce(opr=F.max, ref_opr=np.max)
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.max(x, axis=-1)
np.testing.assert_equal(y.numpy(), np.array([3, 6]).astype(np.int32))
def test_argmin():
common_test_reduce(opr=F.argmin, ref_opr=np.argmin)
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.argmin(x, axis=-1)
np.testing.assert_equal(y.numpy(), np.array([0, 0]).astype(np.int32))
def test_argmax():
common_test_reduce(opr=F.argmax, ref_opr=np.argmax)
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.argmax(x, axis=-2)
np.testing.assert_equal(y.numpy(), np.array([1, 1, 1]).astype(np.int32))
def test_norm():
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.norm(x, axis=-1)
np.testing.assert_equal(
y.numpy().round(decimals=3), np.array([3.742, 8.775]).astype(np.float32)
)
def test_sqrt():
......@@ -136,7 +183,7 @@ def test_sort_empty(is_symbolic):
fn_ = fn
data = np.random.random(shape).astype(np.float32)
for _ in range(3):
outs = fn_(tensor(data))
outs = fn_(Tensor(data))
ref_outs = (np.sort(data), np.argsort(data))
assert len(ref_outs) == len(outs)
for i in range(len(outs)):
......@@ -146,6 +193,12 @@ def test_sort_empty(is_symbolic):
def test_normalize():
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.normalize(x, axis=-1)
np.testing.assert_equal(
y.numpy().round(decimals=1),
np.array([[0.3, 0.5, 0.8], [0.5, 0.6, 0.7]]).astype(np.float32),
)
cases = [
{"input": np.random.random((2, 3, 12, 12)).astype(np.float32)} for i in range(2)
......@@ -177,11 +230,11 @@ def test_sum_neg_axis():
shape = (2, 3)
data = np.random.random(shape).astype(np.float32)
for axis in (-1, -2, (-2, 1), (-1, 0)):
get = F.sum(tensor(data), axis=axis)
get = F.sum(Tensor(data), axis=axis)
ref = np.sum(data, axis=axis)
np.testing.assert_allclose(get.numpy(), ref, rtol=1e-6)
with pytest.raises(AssertionError):
F.sum(tensor(data), axis=(-1, 1))
F.sum(Tensor(data), axis=(-1, 1))
def test_builtin_reduce():
......@@ -204,18 +257,18 @@ def test_non_finite():
data = []
for i in range(2):
data.append(np.random.random(shape).astype(np.float32))
tensorList = [tensor(x) for x in data]
tensorList = [Tensor(x) for x in data]
rst = F.math._check_non_finite(tensorList, 0.7)
np.testing.assert_equal(rst.numpy(), [0])
for i in range(len(tensorList)):
np.testing.assert_allclose(tensorList[i].numpy() / 0.7, data[i], rtol=1e-6)
data[1][0][0][0][0] = float("inf")
rst = F.math._check_non_finite([tensor(x) for x in data], 0.7)
rst = F.math._check_non_finite([Tensor(x) for x in data], 0.7)
np.testing.assert_equal(rst.numpy(), [1])
data[1][0][0][0][0] = float("nan")
rst = F.math._check_non_finite([tensor(x) for x in data], 0.7)
rst = F.math._check_non_finite([Tensor(x) for x in data], 0.7)
np.testing.assert_equal(rst.numpy(), [1])
......@@ -237,7 +290,7 @@ def test_topk(descending, sorted, inp1d, kth_only):
return np.sort(x)
res = F.topk(
tensor(data), k, descending=descending, no_sort=(not sorted), kth_only=kth_only
Tensor(data), k, descending=descending, no_sort=(not sorted), kth_only=kth_only
)
values, indices = res
......@@ -268,7 +321,7 @@ def test_reduce_on_empty_tensor(is_trace):
if is_trace:
fn = jit.trace(symbolic=symbolic)(fn)
for i in range(3):
out = fn(tensor(input, dtype=dtype), axis=axis).numpy()
out = fn(Tensor(input, dtype=dtype), axis=axis).numpy()
out_ref = ref_fn(input.astype(dtype), axis=axis)
np.testing.assert_equal(out, out_ref)
......
......@@ -7,7 +7,7 @@ import pytest
from utils import get_var_value, make_tensor, opr_test
import megengine.functional as F
from megengine import tensor
from megengine import Tensor
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.tensor import megbrain_graph as G
from megengine.core.tensor.utils import astensor1d
......@@ -30,7 +30,7 @@ def test_eye():
np.eye(*case["input"]).astype(dtype),
)
np.testing.assert_allclose(
F.eye(tensor(case["input"]), dtype=dtype).numpy(),
F.eye(Tensor(case["input"]), dtype=dtype).numpy(),
np.eye(*case["input"]).astype(dtype),
)
......@@ -60,7 +60,21 @@ def test_full():
values = [True, 4, 5.0]
for value in values:
np.testing.assert_allclose(F.full(shape, value).numpy(), np.full(shape, value))
assert F.full(shape, value).dtype == tensor(value).dtype
assert F.full(shape, value).dtype == Tensor(value).dtype
@pytest.mark.parametrize("is_varnode", [True, False])
def test_cumsum(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x = Tensor([[1, 2, 3], [4, 5, 6]], np.int32)
y = F.cumsum(x, -1)
np.testing.assert_equal(
y.numpy(), np.array([[1, 3, 6], [4, 9, 15]]).astype(np.int32)
)
@pytest.mark.parametrize("is_varnode", [True, False])
......@@ -83,6 +97,14 @@ def test_concat(is_varnode):
cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network)
x1 = Tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
x2 = Tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
y = F.concat([x1, x2], axis=-1)
np.testing.assert_equal(
y.numpy(),
np.array([[0, 1, 2, 6, 7, 8], [3, 4, 5, 9, 10, 11]]).astype(np.float32),
)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_condtake(is_varnode):
......@@ -139,6 +161,20 @@ def test_stack(is_varnode):
cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network
)
x1 = Tensor(np.arange(0, 3, dtype=np.float32).reshape((3)))
x2 = Tensor(np.arange(6, 9, dtype=np.float32).reshape((3)))
y = F.stack([x1, x2], axis=-1)
np.testing.assert_equal(
y.numpy(), np.array([[0, 6], [1, 7], [2, 8]]).astype(np.float32)
)
x1 = Tensor(np.arange(0, 3, dtype=np.float32).reshape((3)))
x2 = Tensor(np.arange(6, 9, dtype=np.float32).reshape((3)))
y = F.stack([x1, x2], axis=-1)
np.testing.assert_equal(
y.numpy(), np.array([[0, 6], [1, 7], [2, 8]]).astype(np.float32)
)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_split_basic(is_varnode):
......@@ -183,6 +219,12 @@ def test_split_basic(is_varnode):
@pytest.mark.parametrize("symbolic", [None, False, True])
def test_split(symbolic):
x = Tensor(np.random.random((10, 20)), dtype=np.float32)
y = F.split(x, 3, axis=-1)
z = F.split(x, [6, 17], axis=-1)
assert str([i.numpy().shape for i in y]) == "[(10, 7), (10, 7), (10, 6)]"
assert str([i.numpy().shape for i in z]) == "[(10, 6), (10, 11), (10, 3)]"
inp1 = np.random.random((3, 4, 5, 6)).astype(np.float32)
inp2 = np.random.random((0, 4, 5, 6)).astype(np.float32)
......@@ -208,12 +250,43 @@ def test_split(symbolic):
fn = trace(symbolic=symbolic)(func)
for i in range(3 if symbolic is not None else 1):
ref_out = ref(*case)
out = fn(tensor(case[0]), case[1], case[2])
out = fn(Tensor(case[0]), case[1], case[2])
assert len(ref_out) == len(out)
for idx in range(len(ref_out)):
np.testing.assert_equal(ref_out[idx], out[idx].numpy())
def test_gather():
x = Tensor([[1, 2], [3, 4], [5, 6],])
index = Tensor([[0, 1], [1, 0], [1, 1]])
y = F.gather(x, 1, index)
np.testing.assert_equal(
y.numpy(), np.array([[1, 2], [4, 3], [6, 6]]).astype(np.int32)
)
def test_scatter():
x = Tensor(np.zeros(shape=(3, 5), dtype=np.float32))
source = Tensor(
[
[0.9935, 0.9465, 0.2256, 0.8926, 0.4396],
[0.7723, 0.0718, 0.5939, 0.357, 0.4576],
]
)
index = Tensor([[0, 2, 0, 2, 1], [2, 0, 1, 1, 2]])
y = F.scatter(x, -2, index, source)
np.testing.assert_equal(
y.numpy().round(decimals=4),
np.array(
[
[0.9935, 0.0718, 0.2256, 0.0, 0.0],
[0.0, 0.0, 0.5939, 0.357, 0.4396],
[0.7723, 0.9465, 0.0, 0.8926, 0.4576],
]
).astype(np.float32),
)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_swapaxes(is_varnode):
if is_varnode:
......@@ -221,7 +294,7 @@ def test_swapaxes(is_varnode):
else:
network = None
x = tensor(np.array([[1, 2, 3]], dtype=np.int32))
x = Tensor(np.array([[1, 2, 3]], dtype=np.int32))
y = F.swapaxes(x, 0, 1)
np.testing.assert_equal(y.numpy(), np.array([[1], [2], [3]]).astype(np.int32))
......@@ -280,15 +353,15 @@ def test_broadcast_auto_infer(is_varnode):
def test_reshape_on_empty_tensor(is_trace):
input1_shape = (100, 0, 1)
output1_shape = (100, 0, 10)
data1 = tensor(np.random.random(input1_shape).astype(np.float32))
data1 = Tensor(np.random.random(input1_shape).astype(np.float32))
input2_shape = (10, 0)
output2_shape = (0,)
data2 = tensor(np.random.random(input2_shape).astype(np.float32))
data2 = Tensor(np.random.random(input2_shape).astype(np.float32))
input3_shape = (10, 0, 10)
output3_shape = (0, 1, 2, 3)
data3 = tensor(np.random.random(input3_shape).astype(np.float32))
data3 = Tensor(np.random.random(input3_shape).astype(np.float32))
def comp(out, target_shp):
assert out._tuple_shape == target_shp
......@@ -338,7 +411,7 @@ def test_reshape_shape_inference(is_varnode):
def check_shape(output, target):
source = output.shape
if isinstance(source, tensor):
if isinstance(source, Tensor):
source = source.numpy()
np.testing.assert_equal(source, target.shape)
......@@ -366,6 +439,10 @@ def test_squeeze(is_varnode):
else:
network = None
x = Tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
y = F.squeeze(x, -1)
np.testing.assert_equal(y.numpy(), np.array([[[1, 2]]]).astype(np.int32))
x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
xx = make_tensor(x, network)
......@@ -385,6 +462,12 @@ def test_expand_dims(is_varnode):
else:
network = None
x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
y = F.expand_dims(x, -1)
np.testing.assert_equal(
y.numpy(), np.array([[[1], [2], [3]], [[4], [5], [6]]]).astype(np.int32)
)
x = np.arange(6, dtype="float32").reshape(2, 3)
xx = make_tensor(x, network)
......@@ -533,6 +616,22 @@ def test_flatten(is_varnode):
else:
network = None
inp_shape = (2, 2, 3, 3)
x = Tensor(np.arange(36, dtype=np.int32).reshape(inp_shape),)
y = F.flatten(x, -2, -1)
np.testing.assert_equal(
y.numpy(),
np.array(
[
[[0, 1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16, 17]],
[
[18, 19, 20, 21, 22, 23, 24, 25, 26],
[27, 28, 29, 30, 31, 32, 33, 34, 35],
],
]
).astype(np.int32),
)
data0_shape = (2, 3, 4, 5)
data1_shape = (4, 5, 6, 7)
data0 = np.random.random(data0_shape).astype(np.float32)
......@@ -616,15 +715,15 @@ def test_broadcast(is_varnode):
def test_broadcast_on_empty_tensor(is_trace):
input1_shape = (100, 0, 1)
output1_shape = (100, 0, 10)
data1 = tensor(np.random.random(input1_shape).astype(np.float32))
data1 = Tensor(np.random.random(input1_shape).astype(np.float32))
input2_shape = (10, 0)
output2_shape = (10, 10, 0)
data2 = tensor(np.random.random(input2_shape).astype(np.float32))
data2 = Tensor(np.random.random(input2_shape).astype(np.float32))
input3_shape = (0, 0, 1, 10)
output3_shape = (10, 0, 0, 10, 10)
data3 = tensor(np.random.random(input3_shape).astype(np.float32))
data3 = Tensor(np.random.random(input3_shape).astype(np.float32))
def comp(out, target_shp):
assert out._tuple_shape == target_shp
......@@ -705,7 +804,7 @@ def test_utils_astensor1d(is_varnode):
def test_device():
x = tensor([1, 2, 3], dtype="float32")
x = Tensor([1, 2, 3], dtype="float32")
y1 = F.eye(x.shape, dtype="float32")
y2 = F.eye(x.shape, dtype="float32", device=None)
......@@ -789,7 +888,7 @@ def test_copy_d2d(is_varnode):
)
@pytest.mark.parametrize("is_symbolic", [None, True, False])
def test_copy_empty(shape, device_src, device_dst, is_symbolic):
inp = tensor(np.random.randn(*shape).astype("float32"), device=device_src)
inp = Tensor(np.random.randn(*shape).astype("float32"), device=device_src)
def func(inp):
return F.copy(inp, device_dst)
......@@ -885,6 +984,12 @@ def test_roll(shape, shifts, axis, is_varnode):
else:
network = None
x = Tensor([[1, 2], [3, 4], [5, 6]], np.int32)
y = F.roll(x, 1, -1)
np.testing.assert_equal(
y.numpy(), np.array([[2, 1], [4, 3], [6, 5]]).astype(np.int32)
)
inp = np.random.randn(*shape).astype("float32")
def func(inp):
......@@ -904,7 +1009,7 @@ def test_roll(shape, shifts, axis, is_varnode):
)
@pytest.mark.parametrize("is_symbolic", [None, True, False])
def test_roll_empty_tensor(shape, shifts, axis, is_symbolic):
inp = tensor(np.random.randn(*shape).astype("float32"))
inp = Tensor(np.random.randn(*shape).astype("float32"))
def func(inp):
return F.roll(inp, shifts, axis)
......
#include "../dnn_op_helper.h"
#include "megbrain/imperative/ops/autogen.h"
#include "../op_trait.h"
#include "megbrain/opr/indexing.h"
#include "megdnn/oprs/general.h"
namespace mgb {
namespace imperative {
......@@ -12,10 +14,8 @@ namespace indexing_one_hot {
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
auto& op = def.cast_final_safe<IndexingOneHot>();
auto&& op = def.cast_final_safe<IndexingOneHot>();
mgb_assert(input_descs.size() == 2, "IndexingOneHot expects two inputs");
auto comp_node = input_descs[0].comp_node;
TensorLayout src = input_descs[0].layout, index = input_descs[1].layout;
......@@ -28,10 +28,15 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
mgb_assert(src.ndim >= 2, "src ndim must be at least 2");
mgb_assert(src.is_contiguous(), "src should be contiguous");
mgb_assert(
op.axis >= 0 && op.axis < src.ndim, "axis %d not exists in src", op.axis);
-static_cast<int>(src.ndim) <= op.axis &&
op.axis < static_cast<int>(src.ndim),
"axis %d not exists in src", op.axis);
int real_axis = static_cast<int>(op.axis);
if (real_axis < 0) {
real_axis += static_cast<int>(src.ndim);
}
TensorLayout dst = src;
dst.shape[op.axis] = 1;
dst.shape[real_axis] = 1;
dst.init_contiguous_stride();
if (!index.ndim) {
......@@ -40,24 +45,128 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
mgb_assert(index.is_contiguous(), "index should be all contiguous");
mgb_assert(
index.eq_shape(src.remove_axis(op.axis)), "index shape doesn't match src");
index.eq_shape(src.remove_axis(real_axis)),
"index shape doesn't match src");
return {{{dst, comp_node}}, true};
}
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const IndexingOneHot&>(def);
auto&& op = def.cast_final_safe<IndexingOneHot>();
mgb_assert(inputs.size() == 2);
int real_axis = static_cast<int>(op.axis);
if (real_axis < 0) {
real_axis += static_cast<int>(op.ndim);
}
OperatorNodeConfig config{op.make_name()};
return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config);
return opr::IndexingOneHot::make(inputs[0], inputs[1], real_axis, config);
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op = def.cast_final_safe<IndexingOneHot>();
auto&& inp = inputs[0];
auto&& index = inputs[1];
TensorLayout layout = inp->layout();
TensorLayout index_layout = index->layout();
DnnOprCaller<megdnn::IndexingOneHot> dnn_op(inp->comp_node());
auto&& indexing_one_hot_param = dnn_op.op->param();
int real_axis = static_cast<int>(op.axis);
if (real_axis < 0) {
real_axis += static_cast<int>(layout.ndim);
}
mgb_assert(
0 <= real_axis && real_axis < static_cast<int>(layout.ndim),
"Dimension out of range (expected to be in range of [%d, %d], but got %d)",
0, static_cast<int>(layout.ndim) - 1, op.axis);
indexing_one_hot_param = real_axis;
TensorLayout tlayout;
dnn_op.op->deduce_layout(layout, index_layout, tlayout);
TensorPtr out = Tensor::make(tlayout, inp->comp_node());
megdnn::TensorND in = inp->dnn_tensor();
megdnn::TensorND ind = index->dnn_tensor();
TensorLayout m_layout(
{dnn_op.op->get_workspace_in_bytes(layout, index_layout, tlayout)},
dtype::Byte());
auto dnn_workspace = dnn_op.create_workspace(m_layout);
dnn_op.op->exec(in, ind, out->dnn_tensor(), dnn_workspace);
return {out};
}
OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace indexing_one_hot
namespace indexing_set_one_hot {
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
mgb_assert(input_descs.size() == 3, "IndexingSetOneHot expects three inputs");
auto comp_node = input_descs[0].comp_node;
TensorLayout src = input_descs[0].layout, index = input_descs[1].layout;
mgb_assert(index.dtype == dtype::Int32(), "index dtype must be int32");
if (!src.ndim) {
return {{{{{}, src.dtype}, comp_node}}, false};
}
mgb_assert(src.is_contiguous(), "src should be contiguous");
return {{input_descs[0]}, true};
}
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const IndexingSetOneHot&>(def);
mgb_assert(inputs.size() == 3);
int real_axis = static_cast<int>(op.axis);
if (real_axis < 0) {
real_axis += static_cast<int>(op.ndim);
}
OperatorNodeConfig config{op.make_name()};
return opr::IndexingSetOneHot::make(
inputs[0], inputs[1], inputs[2], real_axis, config);
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op = def.cast_final_safe<IndexingSetOneHot>();
auto&& inp = inputs[0];
auto&& index = inputs[1];
auto&& sub = inputs[2];
TensorLayout layout = inp->layout();
TensorLayout index_layout = index->layout();
TensorLayout tlayout = sub->layout();
mgb_assert(layout.is_contiguous());
DnnOprCaller<megdnn::IndexingSetOneHot> dnn_op(inp->comp_node());
auto&& indexing_one_hot_param = dnn_op.op->param();
int real_axis = static_cast<int>(op.axis);
if (real_axis < 0) {
real_axis += static_cast<int>(layout.ndim);
}
indexing_one_hot_param = real_axis;
TensorPtr out = Tensor::make(layout, inp->comp_node());
out->dev_tensor().copy_from_fixlayout(inp->dev_tensor());
megdnn::TensorND in = inp->dnn_tensor();
megdnn::TensorND ind = index->dnn_tensor();
megdnn::TensorND su = sub->dnn_tensor();
TensorLayout m_layout(
{dnn_op.op->get_workspace_in_bytes(layout, index_layout, tlayout)},
dtype::Byte());
auto dnn_workspace = dnn_op.create_workspace(m_layout);
dnn_op.op->exec(out->dnn_tensor(), ind, su, dnn_workspace);
return {out};
}
OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace indexing_set_one_hot
} // anonymous namespace
} // namespace imperative
} // namespace mgb
......
......@@ -372,21 +372,6 @@ OP_TRAIT_REG(GroupLocal, GroupLocal).apply_on_var_node(apply_on_var_node).fallba
} // namespace group_local
} // namespace
namespace {
namespace indexing_set_one_hot {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const IndexingSetOneHot&>(def);
mgb_assert(inputs.size() == 3);
OperatorNodeConfig config{op.make_name()};
return opr::IndexingSetOneHot::make(
inputs[0], inputs[1], inputs[2], op.param(), config);
}
OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
.apply_on_var_node(apply_on_var_node)
.fallback();
} // namespace indexing_set_one_hot
} // namespace
namespace {
namespace typecvt {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
......
......@@ -108,9 +108,17 @@ def Remap: MgbHashableOp<"Remap", [RemapParam]>;
def Resize: MgbHashableOp<"Resize", [ResizeParam]>;
def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]>;
def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]> {
let extraArguments = (ins
MgbI32Attr:$ndim
);
}
def IndexingSetOneHot: MgbHashableOp<"IndexingSetOneHot", [AxisParam]>;
def IndexingSetOneHot: MgbHashableOp<"IndexingSetOneHot", [AxisParam]> {
let extraArguments = (ins
MgbI32Attr:$ndim
);
}
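
The extra `ndim` attribute is filled in by the Python functional layer (see the `builtin.IndexingOneHot(axis=axis, ndim=src.ndim)` change above); a sketch of constructing the op directly, analogous to the new `test_indexing_set_one_hot` test:

```python
import numpy as np
from megengine import Tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin

src = Tensor(np.array([[1.0, 2.0]], dtype=np.float32))
index = Tensor(np.array([0], dtype=np.int32))

# ndim lets apply_on_var_node / apply_on_physical_tensor resolve axis=-1 to axis 1.
op = builtin.IndexingOneHot(axis=-1, ndim=src.ndim)
(result,) = apply(op, src, index)
print(result.numpy())  # [[1.]]
```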
def Copy: MgbHashableOp<"Copy"> {
let extraArguments = (ins
......