Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
24c5c19b
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
24c5c19b
编写于
4月 22, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(imperative): make functional ops support negative axis
GitOrigin-RevId: f61e01270b948ab5bd6ba32b091d7c6b8d7a0745
上级
c76e80bc
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
469 addition
and
96 deletion
+469
-96
dnn/include/megdnn/oprs/general.h
dnn/include/megdnn/oprs/general.h
+1
-1
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+2
-2
imperative/python/megengine/functional/tensor.py
imperative/python/megengine/functional/tensor.py
+5
-24
imperative/python/src/grad_override.cpp
imperative/python/src/grad_override.cpp
+82
-0
imperative/python/test/unit/core/test_autodiff.py
imperative/python/test/unit/core/test_autodiff.py
+47
-0
imperative/python/test/unit/core/test_function.py
imperative/python/test/unit/core/test_function.py
+1
-3
imperative/python/test/unit/functional/test_loss.py
imperative/python/test/unit/functional/test_loss.py
+19
-14
imperative/python/test/unit/functional/test_math.py
imperative/python/test/unit/functional/test_math.py
+62
-9
imperative/python/test/unit/functional/test_tensor.py
imperative/python/test/unit/functional/test_tensor.py
+120
-15
imperative/src/impl/ops/indexing.cpp
imperative/src/impl/ops/indexing.cpp
+120
-11
imperative/src/impl/ops/specializations.cpp
imperative/src/impl/ops/specializations.cpp
+0
-15
src/core/include/megbrain/ir/ops.td
src/core/include/megbrain/ir/ops.td
+10
-2
未找到文件。
dnn/include/megdnn/oprs/general.h
浏览文件 @
24c5c19b
...
...
@@ -1015,7 +1015,7 @@ class IndexingOneHotBase : public OperatorBase {
DEF_OPR_PARAM
(
Axis
);
protected:
void
deduce_layout_fwd
(
MGE_WIN_DECLSPEC_FUC
void
deduce_layout_fwd
(
const
TensorLayout
&
src
,
const
TensorLayout
&
index
,
TensorLayout
&
dst
);
void
check_layout_fwd
(
const
TensorLayout
&
src
,
const
TensorLayout
&
index
,
...
...
imperative/python/megengine/functional/nn.py
浏览文件 @
24c5c19b
...
...
@@ -1558,7 +1558,7 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor:
)
ones_tensor
=
ones
(
list
(
inp
.
shape
)
+
[
1
],
dtype
=
inp
.
dtype
,
device
=
inp
.
device
)
op
=
builtin
.
IndexingSetOneHot
(
axis
=
inp
.
ndim
)
op
=
builtin
.
IndexingSetOneHot
(
axis
=
inp
.
ndim
,
ndim
=
inp
.
ndim
)
(
result
,)
=
apply
(
op
,
zeros_tensor
,
inp
,
ones_tensor
)
return
result
...
...
@@ -1609,7 +1609,7 @@ def indexing_one_hot(
array([1.], dtype=float32)
"""
assert
isinstance
(
src
,
Tensor
),
"src must be of Tensor type"
op
=
builtin
.
IndexingOneHot
(
axis
=
axis
)
op
=
builtin
.
IndexingOneHot
(
axis
=
axis
,
ndim
=
src
.
ndim
)
index
=
convert_single_value
(
index
,
dtype
=
"int32"
,
device
=
src
.
device
)
(
result
,)
=
apply
(
op
,
src
,
index
)
if
not
keepdims
:
...
...
imperative/python/megengine/functional/tensor.py
浏览文件 @
24c5c19b
...
...
@@ -393,6 +393,8 @@ def split(inp, nsplits_or_sections, axis=0):
def
_get_idx
(
index
,
axis
):
index_dims
=
len
(
index
.
shape
)
idx
=
[]
if
axis
<
0
:
axis
+=
index_dims
for
i
in
range
(
index_dims
):
if
i
!=
axis
:
shape
=
[
1
]
*
index_dims
...
...
@@ -457,21 +459,6 @@ def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
"But the input dims:{}, the index dims:{}"
.
format
(
input_dims
,
index_dims
)
)
if
axis
<
0
or
axis
>=
input_dims
:
raise
ValueError
(
"Index axis {} is output of bounds, should in range [0 {})"
.
format
(
axis
,
input_dims
)
)
for
i
in
range
(
input_dims
):
if
i
!=
axis
and
input_shape
[
i
]
!=
index_shape
[
i
]:
raise
ValueError
(
"The input {} and index {} must have the same size apart from axis {}"
.
format
(
input_shape
,
index_shape
,
axis
)
)
idx
=
_get_idx
(
index
,
axis
)
return
inp
[
idx
].
reshape
(
index
.
shape
)
# pylint: disable=no-member
...
...
@@ -524,7 +511,7 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
>>> inp = Tensor(np.zeros(shape=(3,5),dtype=np.float32))
>>> source = Tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]])
>>> index = Tensor([[0,2,0,2,1],[2,0,1,1,2]])
>>> oup = F.scatter(inp, 0, index,source)
>>> oup = F.scatter(inp, 0, index,
source)
>>> oup.numpy()
array([[0.9935, 0.0718, 0.2256, 0. , 0. ],
[0. , 0. , 0.5939, 0.357 , 0.4396],
...
...
@@ -540,13 +527,6 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
if
input_dims
!=
index_dims
or
input_dims
!=
source_dims
:
raise
ValueError
(
"The input, source and index tensor must have same dimensions"
)
if
axis
<
0
or
axis
>=
input_dims
:
raise
ValueError
(
"Index axis {} is output of bounds, should in range [0 {})"
.
format
(
axis
,
input_dims
)
)
for
i
in
range
(
source_dims
):
if
source_shape
[
i
]
>
input_shape
[
i
]:
raise
ValueError
(
...
...
@@ -792,6 +772,8 @@ def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor:
>>> out.numpy().shape
(2, 2, 9)
"""
if
start_axis
<
0
:
start_axis
+=
len
(
inp
.
shape
)
target_shape
=
tuple
(
inp
.
shape
[
i
]
for
i
in
range
(
start_axis
))
+
(
-
1
,)
if
end_axis
!=
-
1
:
target_shape
+=
(
*
inp
.
shape
[
end_axis
+
1
:],)
...
...
@@ -1158,6 +1140,5 @@ def cumsum(inp: Tensor, axis: int):
[ 4 9 15]], dtype=int32, device=xpux:0)
"""
assert
isinstance
(
inp
,
Tensor
),
"input of cumsum must be type of Tensor"
assert
axis
>=
0
and
axis
<
inp
.
ndim
,
"input axis {} out of bound"
.
format
(
axis
)
op
=
builtin
.
Cumsum
(
axis
=
axis
,
exclusive
=
False
,
reverse
=
False
)
return
apply
(
op
,
inp
)[
0
]
imperative/python/src/grad_override.cpp
浏览文件 @
24c5c19b
...
...
@@ -490,6 +490,84 @@ std::optional<ValueRefList> pixelShuffle_grad_rule(
return
imperative
::
apply
(
op
,
inputs
);
}
std
::
optional
<
ValueRefList
>
indexing_grad_rule
(
const
OpDef
&
op
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_require_grad
,
CustomBackward
&
backward
)
{
auto
&&
indexing
=
op
.
cast_final_safe
<
IndexingOneHot
>
();
mgb_assert
(
inputs
.
size
()
==
2
);
bool
flag
=
inputs_require_grad
[
0
];
auto
&&
grad_op
=
IndexingSetOneHot
::
make
(
indexing
.
axis
,
indexing
.
ndim
);
SmallVector
<
ValueRef
>
inputs2
;
if
(
flag
)
{
inputs2
.
push_back
(
get_shape
(
inputs
[
0
]));
for
(
size_t
i
=
1
;
i
<
inputs
.
size
();
++
i
)
{
inputs2
.
push_back
(
inputs
[
i
]);
}
}
auto
maker
=
CustomGradMaker
(
backward
,
inputs
.
size
());
maker
.
output_size
(
1
).
output_captured
(
0
,
false
);
maker
.
backward
([
inputs
=
std
::
move
(
inputs2
),
grad_op_
=
std
::
move
(
grad_op
)](
Span
<
ValueRef
>
grads
)
{
mgb_assert
(
grads
.
size
()
==
1
);
ValueRef
grad
=
grads
[
0
];
SmallVector
<
ValueRef
>
ret
(
1
);
if
(
grad
&&
inputs
[
0
])
{
ValueRefList
args_
(
inputs
.
size
()
+
1
);
auto
&&
zeros
=
make_empty_tensor
(
grad
.
device
(),
inputs
[
0
],
grad
.
dtype
());
args_
[
0
]
=
zeros
;
args_
[
1
]
=
inputs
[
1
];
args_
[
2
]
=
grads
[
0
];
ret
[
0
]
=
imperative
::
apply
(
*
grad_op_
,
args_
)[
0
];
}
return
ret
;
});
maker
.
finalize
();
return
imperative
::
apply
(
op
,
inputs
);
}
std
::
optional
<
ValueRefList
>
indexing_set_one_hot_grad_rule
(
const
OpDef
&
op
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_require_grad
,
CustomBackward
&
backward
)
{
auto
&&
indexingSetOneHot
=
op
.
cast_final_safe
<
IndexingSetOneHot
>
();
mgb_assert
(
inputs
.
size
()
==
3
);
SmallVector
<
ValueRef
>
inputs2
;
inputs2
.
push_back
(
get_shape
(
inputs
[
0
]));
inputs2
.
push_back
(
inputs
[
1
]);
inputs2
.
push_back
(
get_shape
(
inputs
[
2
]));
auto
maker
=
CustomGradMaker
(
backward
,
inputs
.
size
());
maker
.
output_size
(
1
).
output_captured
(
0
,
false
);
maker
.
backward
([
inputs
=
std
::
move
(
inputs2
),
&
indexingSetOneHot
](
Span
<
ValueRef
>
grads
)
{
mgb_assert
(
grads
.
size
()
==
1
);
ValueRef
grad
=
grads
[
0
];
SmallVector
<
ValueRef
>
ret
(
3
);
if
(
!
grad
)
{
return
ret
;
}
if
(
inputs
[
0
])
{
auto
&&
grad_op
=
IndexingSetOneHot
::
make
(
indexingSetOneHot
.
axis
,
indexingSetOneHot
.
ndim
);
ValueRefList
args_
(
inputs
.
size
());
auto
&&
zeros
=
make_empty_tensor
(
grad
.
device
(),
inputs
[
2
],
grad
.
dtype
());
args_
[
0
]
=
grads
[
0
];
args_
[
1
]
=
inputs
[
1
];
args_
[
2
]
=
zeros
;
ret
[
0
]
=
imperative
::
apply
(
*
grad_op
,
args_
)[
0
];
}
if
(
inputs
[
2
])
{
auto
&&
grad_op
=
IndexingOneHot
::
make
(
indexingSetOneHot
.
axis
,
indexingSetOneHot
.
ndim
);
ValueRefList
args_
(
inputs
.
size
()
-
1
);
args_
[
0
]
=
grads
[
0
];
args_
[
1
]
=
inputs
[
1
];
ret
[
2
]
=
imperative
::
apply
(
*
grad_op
,
args_
)[
0
];
}
return
ret
;
});
maker
.
finalize
();
return
imperative
::
apply
(
op
,
inputs
);
}
std
::
optional
<
ValueRefList
>
fastpathcopy_grad_rule
(
const
OpDef
&
op
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
inputs_require_grad
,
CustomBackward
&
backward
)
{
...
...
@@ -521,6 +599,10 @@ struct Init {
CustomBackward
::
register_grad_rule
(
AddAxis
::
typeinfo
(),
addAxis_grad_rule
);
CustomBackward
::
register_grad_rule
(
RemoveAxis
::
typeinfo
(),
removeAxis_grad_rule
);
CustomBackward
::
register_grad_rule
(
IndexingOneHot
::
typeinfo
(),
indexing_grad_rule
);
CustomBackward
::
register_grad_rule
(
IndexingSetOneHot
::
typeinfo
(),
indexing_set_one_hot_grad_rule
);
CustomBackward
::
register_grad_rule
(
FastpathCopy
::
typeinfo
(),
fastpathcopy_grad_rule
);
CustomBackward
::
register_grad_rule
(
...
...
imperative/python/test/unit/core/test_autodiff.py
浏览文件 @
24c5c19b
...
...
@@ -8,11 +8,15 @@ import megengine as mge
import
megengine.distributed
as
dist
import
megengine.functional
as
F
import
megengine.module
as
M
from
megengine
import
Tensor
from
megengine.core
import
_imperative_rt
from
megengine.core._imperative_rt
import
CompNode
,
TensorAttr
,
imperative
from
megengine.core._imperative_rt.core2
import
TensorWeakRef
,
apply
,
sync
from
megengine.core.autodiff.grad
import
Grad
from
megengine.core.ops
import
builtin
from
megengine.core.ops.builtin
import
Elemwise
,
Identity
from
megengine.functional.distributed
import
remote_recv
,
remote_send
from
megengine.functional.tensor
import
ones
,
zeros
def
_elwise
(
mode
):
...
...
@@ -553,3 +557,46 @@ def test_matmul():
if
ydim
==
1
and
transposeB
==
True
:
continue
test_one
(
xdim
,
ydim
,
transposeA
,
transposeB
)
def
test_indexing
():
x
=
np
.
array
([[
1.0
,
2.0
]]).
astype
(
"float32"
)
x
=
mge
.
Tensor
(
x
)
index
=
mge
.
Tensor
([
0
])
with
Grad
()
as
grad
:
grad
.
wrt
(
x
,
callback
=
save_to
(
x
))
def
f
(
x
):
return
F
.
indexing_one_hot
(
x
,
index
,
-
1
)
y
=
f
(
x
)
grad
(
y
,
F
.
ones_like
(
y
))
np
.
testing
.
assert_equal
(
np
.
array
([[
1
,
0
]],
dtype
=
np
.
float32
),
x
.
grad
.
numpy
())
def
test_indexing_set_one_hot
():
x
=
mge
.
tensor
(
np
.
arange
(
1
,
4
,
dtype
=
np
.
int32
))
with
Grad
()
as
grad
:
zeros_tensor
=
zeros
((
3
,
4
),
dtype
=
x
.
dtype
,
device
=
x
.
device
)
ones_tensor
=
ones
((
3
,
1
),
dtype
=
x
.
dtype
,
device
=
x
.
device
)
grad
.
wrt
(
zeros_tensor
,
callback
=
save_to
(
zeros_tensor
))
grad
.
wrt
(
ones_tensor
,
callback
=
save_to
(
ones_tensor
))
def
f
(
x
):
op
=
builtin
.
IndexingSetOneHot
(
axis
=
x
.
ndim
,
ndim
=
x
.
ndim
)
(
result
,)
=
apply
(
op
,
zeros_tensor
,
x
,
ones_tensor
)
return
result
y
=
f
(
x
)
grad
(
y
,
F
.
ones_like
(
y
))
np
.
testing
.
assert_equal
(
np
.
array
([[
1
,
0
,
1
,
1
],
[
1
,
1
,
0
,
1
],
[
1
,
1
,
1
,
0
]],
dtype
=
np
.
int32
),
zeros_tensor
.
grad
.
numpy
(),
)
np
.
testing
.
assert_equal
(
np
.
array
([[
1
],
[
1
],
[
1
]],
dtype
=
np
.
int32
),
ones_tensor
.
grad
.
numpy
(),
)
imperative/python/test/unit/core/test_function.py
浏览文件 @
24c5c19b
...
...
@@ -6,9 +6,7 @@ import pytest
import
megengine.autodiff
as
ad
import
megengine.functional
as
F
import
megengine.optimizer
as
optimizer
from
megengine
import
Parameter
from
megengine
import
Tensor
as
tensor
from
megengine
import
tensor
from
megengine
import
Parameter
,
Tensor
,
tensor
from
megengine.autodiff
import
Function
from
megengine.module
import
Module
...
...
imperative/python/test/unit/functional/test_loss.py
浏览文件 @
24c5c19b
...
...
@@ -3,15 +3,15 @@ import numpy as np
import
pytest
import
megengine.functional
as
F
from
megengine
import
t
ensor
import
megengine.tensor
as
T
ensor
def
test_cross_entropy_with_logits
():
data
=
t
ensor
([[
0
,
50
],
[
0
,
-
150
]]).
astype
(
np
.
float32
)
label
=
t
ensor
([
1
,
0
]).
astype
(
np
.
int32
)
data
=
T
ensor
([[
0
,
50
],
[
0
,
-
150
]]).
astype
(
np
.
float32
)
label
=
T
ensor
([
1
,
0
]).
astype
(
np
.
int32
)
loss
=
F
.
nn
.
cross_entropy
(
data
,
label
)
np
.
testing
.
assert_allclose
(
loss
.
numpy
(),
0.0
)
label
=
t
ensor
([
0
,
1
]).
astype
(
np
.
int32
)
label
=
T
ensor
([
0
,
1
]).
astype
(
np
.
int32
)
loss
=
F
.
nn
.
cross_entropy
(
data
,
label
)
np
.
testing
.
assert_allclose
(
loss
.
numpy
(),
100
)
...
...
@@ -35,19 +35,24 @@ def test_cross_entropy():
x
[
i
,
y
[
i
]]
+=
np
.
random
.
rand
()
*
2
x
=
softmax
(
x
)
l_ref
=
ref
(
x
,
y
)
l
=
F
.
nn
.
cross_entropy
(
tensor
(
x
,
"float32"
),
t
ensor
(
y
,
"int32"
),
with_logits
=
False
)
l
=
F
.
nn
.
cross_entropy
(
Tensor
(
x
,
"float32"
),
T
ensor
(
y
,
"int32"
),
with_logits
=
False
)
np
.
testing
.
assert_allclose
(
l
.
numpy
(),
l_ref
,
1e-6
,
1e-6
)
l1
=
F
.
nn
.
cross_entropy
(
Tensor
(
x
,
"float32"
),
Tensor
(
y
,
"int32"
),
axis
=-
1
,
with_logits
=
False
)
np
.
testing
.
assert_allclose
(
l1
.
numpy
(),
l_ref
,
1e-6
,
1e-6
)
def
test_cross_entropy_reduction
():
logits
=
np
.
random
.
randn
(
16
,
10
)
label
=
np
.
random
.
randint
(
10
,
size
=
[
16
])
logits
=
t
ensor
(
logits
,
dtype
=
"float32"
)
label
=
t
ensor
(
label
,
dtype
=
"int32"
)
logits
=
T
ensor
(
logits
,
dtype
=
"float32"
)
label
=
T
ensor
(
label
,
dtype
=
"int32"
)
perm
=
np
.
random
.
permutation
(
16
)
logits_perm
=
t
ensor
(
logits
[
perm
],
dtype
=
"float32"
)
label_perm
=
t
ensor
(
label
[
perm
],
dtype
=
"int32"
)
logits_perm
=
T
ensor
(
logits
[
perm
],
dtype
=
"float32"
)
label_perm
=
T
ensor
(
label
[
perm
],
dtype
=
"int32"
)
loss
=
F
.
nn
.
cross_entropy
(
logits
,
label
,
reduction
=
"none"
)
loss_perm
=
F
.
nn
.
cross_entropy
(
logits_perm
,
label_perm
,
reduction
=
"none"
)
...
...
@@ -160,18 +165,18 @@ def _ctc_npy_single_seq(pred, label, blank):
def
test_ctc_loss
():
def
test_func
(
T
,
C
,
N
):
input
=
np
.
random
.
randn
(
T
,
N
,
C
)
input
=
F
.
softmax
(
t
ensor
(
input
),
axis
=-
1
).
numpy
()
input
=
F
.
softmax
(
T
ensor
(
input
),
axis
=-
1
).
numpy
()
input_lengths
=
np
.
ones
(
N
,
dtype
=
np
.
int32
)
*
T
target_lengths
=
np
.
random
.
randint
(
low
=
1
,
high
=
T
+
1
,
size
=
(
N
,),
dtype
=
np
.
int32
)
target
=
np
.
random
.
randint
(
low
=
1
,
high
=
C
,
size
=
(
sum
(
target_lengths
)),
dtype
=
np
.
int32
)
input_mge
=
t
ensor
(
input
)
input_lengths_mge
=
t
ensor
(
input_lengths
)
input_mge
=
T
ensor
(
input
)
input_lengths_mge
=
T
ensor
(
input_lengths
)
target_mge
=
t
ensor
(
target
)
target_lengths_mge
=
t
ensor
(
target_lengths
)
target_mge
=
T
ensor
(
target
)
target_lengths_mge
=
T
ensor
(
target_lengths
)
blank
=
np
.
random
.
randint
(
C
)
for
method
in
[
"mean"
,
"sum"
,
"none"
]:
...
...
imperative/python/test/unit/functional/test_math.py
浏览文件 @
24c5c19b
...
...
@@ -6,7 +6,7 @@ import pytest
from
utils
import
opr_test
import
megengine.functional
as
F
from
megengine
import
jit
,
tensor
from
megengine
import
Tensor
,
jit
,
tensor
from
megengine.core._imperative_rt.core2
import
apply
from
megengine.core.ops
import
builtin
...
...
@@ -61,37 +61,84 @@ def common_test_reduce(opr, ref_opr):
def
test_sum
():
common_test_reduce
(
opr
=
F
.
sum
,
ref_opr
=
np
.
sum
)
x
=
Tensor
(
np
.
arange
(
1
,
7
,
dtype
=
np
.
int32
).
reshape
(
2
,
3
))
y
=
F
.
sum
(
x
,
axis
=-
1
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
np
.
array
([
6
,
15
]).
astype
(
np
.
int32
))
def
test_prod
():
common_test_reduce
(
opr
=
F
.
prod
,
ref_opr
=
np
.
prod
)
x
=
Tensor
(
np
.
arange
(
1
,
7
,
dtype
=
np
.
int32
).
reshape
(
2
,
3
))
y
=
F
.
prod
(
x
,
axis
=-
2
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
np
.
array
([
4
,
10
,
18
]).
astype
(
np
.
int32
))
def
test_mean
():
common_test_reduce
(
opr
=
F
.
mean
,
ref_opr
=
np
.
mean
)
x
=
Tensor
(
np
.
arange
(
1
,
7
,
dtype
=
np
.
int32
).
reshape
(
2
,
3
))
y
=
F
.
mean
(
x
,
axis
=-
2
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
np
.
array
([
2.5
,
3.5
,
4.5
]).
astype
(
np
.
float32
))
def
test_var
():
common_test_reduce
(
opr
=
F
.
var
,
ref_opr
=
np
.
var
)
x
=
Tensor
(
np
.
arange
(
1
,
7
,
dtype
=
np
.
int32
).
reshape
(
2
,
3
))
y
=
F
.
var
(
x
,
axis
=-
2
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
np
.
array
([
2.25
,
2.25
,
2.25
]).
astype
(
np
.
float32
))
def
test_std
():
common_test_reduce
(
opr
=
F
.
std
,
ref_opr
=
np
.
std
)
x
=
Tensor
(
np
.
arange
(
1
,
7
,
dtype
=
np
.
int32
).
reshape
(
2
,
3
))
y
=
F
.
std
(
x
,
axis
=-
2
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
np
.
array
([
1.5
,
1.5
,
1.5
]).
astype
(
np
.
float32
))
x
=
Tensor
(
np
.
arange
(
1
,
7
,
dtype
=
np
.
int32
).
reshape
(
2
,
3
))
y
=
F
.
std
(
x
,
axis
=-
2
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
np
.
array
([
1.5
,
1.5
,
1.5
]).
astype
(
np
.
float32
))
def
test_min
():
common_test_reduce
(
opr
=
F
.
min
,
ref_opr
=
np
.
min
)
x
=
Tensor
(
np
.
arange
(
1
,
7
,
dtype
=
np
.
int32
).
reshape
(
2
,
3
))
y
=
F
.
min
(
x
,
axis
=-
1
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
np
.
array
([
1
,
4
]).
astype
(
np
.
int32
))
def
test_max
():
common_test_reduce
(
opr
=
F
.
max
,
ref_opr
=
np
.
max
)
x
=
Tensor
(
np
.
arange
(
1
,
7
,
dtype
=
np
.
int32
).
reshape
(
2
,
3
))
y
=
F
.
max
(
x
,
axis
=-
1
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
np
.
array
([
3
,
6
]).
astype
(
np
.
int32
))
def
test_argmin
():
common_test_reduce
(
opr
=
F
.
argmin
,
ref_opr
=
np
.
argmin
)
x
=
Tensor
(
np
.
arange
(
1
,
7
,
dtype
=
np
.
int32
).
reshape
(
2
,
3
))
y
=
F
.
argmin
(
x
,
axis
=-
1
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
np
.
array
([
0
,
0
]).
astype
(
np
.
int32
))
def
test_argmax
():
common_test_reduce
(
opr
=
F
.
argmax
,
ref_opr
=
np
.
argmax
)
x
=
Tensor
(
np
.
arange
(
1
,
7
,
dtype
=
np
.
int32
).
reshape
(
2
,
3
))
y
=
F
.
argmax
(
x
,
axis
=-
2
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
np
.
array
([
1
,
1
,
1
]).
astype
(
np
.
int32
))
def
test_norm
():
x
=
Tensor
(
np
.
arange
(
1
,
7
,
dtype
=
np
.
int32
).
reshape
(
2
,
3
))
y
=
F
.
norm
(
x
,
axis
=-
1
)
np
.
testing
.
assert_equal
(
y
.
numpy
().
round
(
decimals
=
3
),
np
.
array
([
3.742
,
8.775
]).
astype
(
np
.
float32
)
)
def
test_sqrt
():
...
...
@@ -136,7 +183,7 @@ def test_sort_empty(is_symbolic):
fn_
=
fn
data
=
np
.
random
.
random
(
shape
).
astype
(
np
.
float32
)
for
_
in
range
(
3
):
outs
=
fn_
(
t
ensor
(
data
))
outs
=
fn_
(
T
ensor
(
data
))
ref_outs
=
(
np
.
sort
(
data
),
np
.
argsort
(
data
))
assert
len
(
ref_outs
)
==
len
(
outs
)
for
i
in
range
(
len
(
outs
)):
...
...
@@ -146,6 +193,12 @@ def test_sort_empty(is_symbolic):
def
test_normalize
():
x
=
Tensor
(
np
.
arange
(
1
,
7
,
dtype
=
np
.
int32
).
reshape
(
2
,
3
))
y
=
F
.
normalize
(
x
,
axis
=-
1
)
np
.
testing
.
assert_equal
(
y
.
numpy
().
round
(
decimals
=
1
),
np
.
array
([[
0.3
,
0.5
,
0.8
],
[
0.5
,
0.6
,
0.7
]]).
astype
(
np
.
float32
),
)
cases
=
[
{
"input"
:
np
.
random
.
random
((
2
,
3
,
12
,
12
)).
astype
(
np
.
float32
)}
for
i
in
range
(
2
)
...
...
@@ -177,11 +230,11 @@ def test_sum_neg_axis():
shape
=
(
2
,
3
)
data
=
np
.
random
.
random
(
shape
).
astype
(
np
.
float32
)
for
axis
in
(
-
1
,
-
2
,
(
-
2
,
1
),
(
-
1
,
0
)):
get
=
F
.
sum
(
t
ensor
(
data
),
axis
=
axis
)
get
=
F
.
sum
(
T
ensor
(
data
),
axis
=
axis
)
ref
=
np
.
sum
(
data
,
axis
=
axis
)
np
.
testing
.
assert_allclose
(
get
.
numpy
(),
ref
,
rtol
=
1e-6
)
with
pytest
.
raises
(
AssertionError
):
F
.
sum
(
t
ensor
(
data
),
axis
=
(
-
1
,
1
))
F
.
sum
(
T
ensor
(
data
),
axis
=
(
-
1
,
1
))
def
test_builtin_reduce
():
...
...
@@ -204,18 +257,18 @@ def test_non_finite():
data
=
[]
for
i
in
range
(
2
):
data
.
append
(
np
.
random
.
random
(
shape
).
astype
(
np
.
float32
))
tensorList
=
[
t
ensor
(
x
)
for
x
in
data
]
tensorList
=
[
T
ensor
(
x
)
for
x
in
data
]
rst
=
F
.
math
.
_check_non_finite
(
tensorList
,
0.7
)
np
.
testing
.
assert_equal
(
rst
.
numpy
(),
[
0
])
for
i
in
range
(
len
(
tensorList
)):
np
.
testing
.
assert_allclose
(
tensorList
[
i
].
numpy
()
/
0.7
,
data
[
i
],
rtol
=
1e-6
)
data
[
1
][
0
][
0
][
0
][
0
]
=
float
(
"inf"
)
rst
=
F
.
math
.
_check_non_finite
([
t
ensor
(
x
)
for
x
in
data
],
0.7
)
rst
=
F
.
math
.
_check_non_finite
([
T
ensor
(
x
)
for
x
in
data
],
0.7
)
np
.
testing
.
assert_equal
(
rst
.
numpy
(),
[
1
])
data
[
1
][
0
][
0
][
0
][
0
]
=
float
(
"nan"
)
rst
=
F
.
math
.
_check_non_finite
([
t
ensor
(
x
)
for
x
in
data
],
0.7
)
rst
=
F
.
math
.
_check_non_finite
([
T
ensor
(
x
)
for
x
in
data
],
0.7
)
np
.
testing
.
assert_equal
(
rst
.
numpy
(),
[
1
])
...
...
@@ -237,7 +290,7 @@ def test_topk(descending, sorted, inp1d, kth_only):
return
np
.
sort
(
x
)
res
=
F
.
topk
(
t
ensor
(
data
),
k
,
descending
=
descending
,
no_sort
=
(
not
sorted
),
kth_only
=
kth_only
T
ensor
(
data
),
k
,
descending
=
descending
,
no_sort
=
(
not
sorted
),
kth_only
=
kth_only
)
values
,
indices
=
res
...
...
@@ -268,7 +321,7 @@ def test_reduce_on_empty_tensor(is_trace):
if
is_trace
:
fn
=
jit
.
trace
(
symbolic
=
symbolic
)(
fn
)
for
i
in
range
(
3
):
out
=
fn
(
t
ensor
(
input
,
dtype
=
dtype
),
axis
=
axis
).
numpy
()
out
=
fn
(
T
ensor
(
input
,
dtype
=
dtype
),
axis
=
axis
).
numpy
()
out_ref
=
ref_fn
(
input
.
astype
(
dtype
),
axis
=
axis
)
np
.
testing
.
assert_equal
(
out
,
out_ref
)
...
...
imperative/python/test/unit/functional/test_tensor.py
浏览文件 @
24c5c19b
...
...
@@ -7,7 +7,7 @@ import pytest
from
utils
import
get_var_value
,
make_tensor
,
opr_test
import
megengine.functional
as
F
from
megengine
import
t
ensor
from
megengine
import
T
ensor
from
megengine.core._trace_option
import
use_symbolic_shape
from
megengine.core.tensor
import
megbrain_graph
as
G
from
megengine.core.tensor.utils
import
astensor1d
...
...
@@ -30,7 +30,7 @@ def test_eye():
np
.
eye
(
*
case
[
"input"
]).
astype
(
dtype
),
)
np
.
testing
.
assert_allclose
(
F
.
eye
(
t
ensor
(
case
[
"input"
]),
dtype
=
dtype
).
numpy
(),
F
.
eye
(
T
ensor
(
case
[
"input"
]),
dtype
=
dtype
).
numpy
(),
np
.
eye
(
*
case
[
"input"
]).
astype
(
dtype
),
)
...
...
@@ -60,7 +60,21 @@ def test_full():
values
=
[
True
,
4
,
5.0
]
for
value
in
values
:
np
.
testing
.
assert_allclose
(
F
.
full
(
shape
,
value
).
numpy
(),
np
.
full
(
shape
,
value
))
assert
F
.
full
(
shape
,
value
).
dtype
==
tensor
(
value
).
dtype
assert
F
.
full
(
shape
,
value
).
dtype
==
Tensor
(
value
).
dtype
@
pytest
.
mark
.
parametrize
(
"is_varnode"
,
[
True
,
False
])
def
test_cumsum
(
is_varnode
):
if
is_varnode
:
network
=
Network
()
else
:
network
=
None
x
=
Tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
np
.
int32
)
y
=
F
.
cumsum
(
x
,
-
1
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
np
.
array
([[
1
,
3
,
6
],
[
4
,
9
,
15
]]).
astype
(
np
.
int32
)
)
@
pytest
.
mark
.
parametrize
(
"is_varnode"
,
[
True
,
False
])
...
...
@@ -83,6 +97,14 @@ def test_concat(is_varnode):
cases
=
[{
"input"
:
[
data1
,
data2
]},
{
"input"
:
[
data1
,
data3
]}]
opr_test
(
cases
,
run
,
ref_fn
=
lambda
x
,
y
:
np
.
concatenate
([
x
,
y
]),
network
=
network
)
x1
=
Tensor
(
np
.
arange
(
0
,
6
,
dtype
=
np
.
float32
).
reshape
((
2
,
3
)))
x2
=
Tensor
(
np
.
arange
(
6
,
12
,
dtype
=
np
.
float32
).
reshape
((
2
,
3
)))
y
=
F
.
concat
([
x1
,
x2
],
axis
=-
1
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
np
.
array
([[
0
,
1
,
2
,
6
,
7
,
8
],
[
3
,
4
,
5
,
9
,
10
,
11
]]).
astype
(
np
.
float32
),
)
@
pytest
.
mark
.
parametrize
(
"is_varnode"
,
[
True
,
False
])
def
test_condtake
(
is_varnode
):
...
...
@@ -139,6 +161,20 @@ def test_stack(is_varnode):
cases
,
run
,
ref_fn
=
lambda
x
,
y
:
np
.
stack
([
x
,
y
],
axis
=
ai
),
network
=
network
)
x1
=
Tensor
(
np
.
arange
(
0
,
3
,
dtype
=
np
.
float32
).
reshape
((
3
)))
x2
=
Tensor
(
np
.
arange
(
6
,
9
,
dtype
=
np
.
float32
).
reshape
((
3
)))
y
=
F
.
stack
([
x1
,
x2
],
axis
=-
1
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
np
.
array
([[
0
,
6
],
[
1
,
7
],
[
2
,
8
]]).
astype
(
np
.
float32
)
)
x1
=
Tensor
(
np
.
arange
(
0
,
3
,
dtype
=
np
.
float32
).
reshape
((
3
)))
x2
=
Tensor
(
np
.
arange
(
6
,
9
,
dtype
=
np
.
float32
).
reshape
((
3
)))
y
=
F
.
stack
([
x1
,
x2
],
axis
=-
1
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
np
.
array
([[
0
,
6
],
[
1
,
7
],
[
2
,
8
]]).
astype
(
np
.
float32
)
)
@
pytest
.
mark
.
parametrize
(
"is_varnode"
,
[
True
,
False
])
def
test_split_basic
(
is_varnode
):
...
...
@@ -183,6 +219,12 @@ def test_split_basic(is_varnode):
@
pytest
.
mark
.
parametrize
(
"symbolic"
,
[
None
,
False
,
True
])
def
test_split
(
symbolic
):
x
=
Tensor
(
np
.
random
.
random
((
10
,
20
)),
dtype
=
np
.
float32
)
y
=
F
.
split
(
x
,
3
,
axis
=-
1
)
z
=
F
.
split
(
x
,
[
6
,
17
],
axis
=-
1
)
assert
str
([
i
.
numpy
().
shape
for
i
in
y
])
==
"[(10, 7), (10, 7), (10, 6)]"
assert
str
([
i
.
numpy
().
shape
for
i
in
z
])
==
"[(10, 6), (10, 11), (10, 3)]"
inp1
=
np
.
random
.
random
((
3
,
4
,
5
,
6
)).
astype
(
np
.
float32
)
inp2
=
np
.
random
.
random
((
0
,
4
,
5
,
6
)).
astype
(
np
.
float32
)
...
...
@@ -208,12 +250,43 @@ def test_split(symbolic):
fn
=
trace
(
symbolic
=
symbolic
)(
func
)
for
i
in
range
(
3
if
symbolic
is
not
None
else
1
):
ref_out
=
ref
(
*
case
)
out
=
fn
(
t
ensor
(
case
[
0
]),
case
[
1
],
case
[
2
])
out
=
fn
(
T
ensor
(
case
[
0
]),
case
[
1
],
case
[
2
])
assert
len
(
ref_out
)
==
len
(
out
)
for
idx
in
range
(
len
(
ref_out
)):
np
.
testing
.
assert_equal
(
ref_out
[
idx
],
out
[
idx
].
numpy
())
def
test_gather
():
x
=
Tensor
([[
1
,
2
],
[
3
,
4
],
[
5
,
6
],])
index
=
Tensor
([[
0
,
1
],
[
1
,
0
],
[
1
,
1
]])
y
=
F
.
gather
(
x
,
1
,
index
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
np
.
array
([[
1
,
2
],
[
4
,
3
],
[
6
,
6
]]).
astype
(
np
.
int32
)
)
def
test_scatter
():
x
=
Tensor
(
np
.
zeros
(
shape
=
(
3
,
5
),
dtype
=
np
.
float32
))
source
=
Tensor
(
[
[
0.9935
,
0.9465
,
0.2256
,
0.8926
,
0.4396
],
[
0.7723
,
0.0718
,
0.5939
,
0.357
,
0.4576
],
]
)
index
=
Tensor
([[
0
,
2
,
0
,
2
,
1
],
[
2
,
0
,
1
,
1
,
2
]])
y
=
F
.
scatter
(
x
,
-
2
,
index
,
source
)
np
.
testing
.
assert_equal
(
y
.
numpy
().
round
(
decimals
=
4
),
np
.
array
(
[
[
0.9935
,
0.0718
,
0.2256
,
0.0
,
0.0
],
[
0.0
,
0.0
,
0.5939
,
0.357
,
0.4396
],
[
0.7723
,
0.9465
,
0.0
,
0.8926
,
0.4576
],
]
).
astype
(
np
.
float32
),
)
@
pytest
.
mark
.
parametrize
(
"is_varnode"
,
[
True
,
False
])
def
test_swapaxes
(
is_varnode
):
if
is_varnode
:
...
...
@@ -221,7 +294,7 @@ def test_swapaxes(is_varnode):
else
:
network
=
None
x
=
t
ensor
(
np
.
array
([[
1
,
2
,
3
]],
dtype
=
np
.
int32
))
x
=
T
ensor
(
np
.
array
([[
1
,
2
,
3
]],
dtype
=
np
.
int32
))
y
=
F
.
swapaxes
(
x
,
0
,
1
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
np
.
array
([[
1
],
[
2
],
[
3
]]).
astype
(
np
.
int32
))
...
...
@@ -280,15 +353,15 @@ def test_broadcast_auto_infer(is_varnode):
def
test_reshape_on_empty_tensor
(
is_trace
):
input1_shape
=
(
100
,
0
,
1
)
output1_shape
=
(
100
,
0
,
10
)
data1
=
t
ensor
(
np
.
random
.
random
(
input1_shape
).
astype
(
np
.
float32
))
data1
=
T
ensor
(
np
.
random
.
random
(
input1_shape
).
astype
(
np
.
float32
))
input2_shape
=
(
10
,
0
)
output2_shape
=
(
0
,)
data2
=
t
ensor
(
np
.
random
.
random
(
input2_shape
).
astype
(
np
.
float32
))
data2
=
T
ensor
(
np
.
random
.
random
(
input2_shape
).
astype
(
np
.
float32
))
input3_shape
=
(
10
,
0
,
10
)
output3_shape
=
(
0
,
1
,
2
,
3
)
data3
=
t
ensor
(
np
.
random
.
random
(
input3_shape
).
astype
(
np
.
float32
))
data3
=
T
ensor
(
np
.
random
.
random
(
input3_shape
).
astype
(
np
.
float32
))
def
comp
(
out
,
target_shp
):
assert
out
.
_tuple_shape
==
target_shp
...
...
@@ -338,7 +411,7 @@ def test_reshape_shape_inference(is_varnode):
def
check_shape
(
output
,
target
):
source
=
output
.
shape
if
isinstance
(
source
,
t
ensor
):
if
isinstance
(
source
,
T
ensor
):
source
=
source
.
numpy
()
np
.
testing
.
assert_equal
(
source
,
target
.
shape
)
...
...
@@ -366,6 +439,10 @@ def test_squeeze(is_varnode):
else
:
network
=
None
x
=
Tensor
(
np
.
array
([
1
,
2
],
dtype
=
np
.
int32
).
reshape
(
1
,
1
,
2
,
1
))
y
=
F
.
squeeze
(
x
,
-
1
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
np
.
array
([[[
1
,
2
]]]).
astype
(
np
.
int32
))
x
=
np
.
arange
(
6
,
dtype
=
"float32"
).
reshape
(
1
,
2
,
3
,
1
)
xx
=
make_tensor
(
x
,
network
)
...
...
@@ -385,6 +462,12 @@ def test_expand_dims(is_varnode):
else
:
network
=
None
x
=
Tensor
(
np
.
arange
(
1
,
7
,
dtype
=
np
.
int32
).
reshape
(
2
,
3
))
y
=
F
.
expand_dims
(
x
,
-
1
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
np
.
array
([[[
1
],
[
2
],
[
3
]],
[[
4
],
[
5
],
[
6
]]]).
astype
(
np
.
int32
)
)
x
=
np
.
arange
(
6
,
dtype
=
"float32"
).
reshape
(
2
,
3
)
xx
=
make_tensor
(
x
,
network
)
...
...
@@ -533,6 +616,22 @@ def test_flatten(is_varnode):
else
:
network
=
None
inp_shape
=
(
2
,
2
,
3
,
3
)
x
=
Tensor
(
np
.
arange
(
36
,
dtype
=
np
.
int32
).
reshape
(
inp_shape
),)
y
=
F
.
flatten
(
x
,
-
2
,
-
1
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
np
.
array
(
[
[[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
],
[
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
]],
[
[
18
,
19
,
20
,
21
,
22
,
23
,
24
,
25
,
26
],
[
27
,
28
,
29
,
30
,
31
,
32
,
33
,
34
,
35
],
],
]
).
astype
(
np
.
int32
),
)
data0_shape
=
(
2
,
3
,
4
,
5
)
data1_shape
=
(
4
,
5
,
6
,
7
)
data0
=
np
.
random
.
random
(
data0_shape
).
astype
(
np
.
float32
)
...
...
@@ -616,15 +715,15 @@ def test_broadcast(is_varnode):
def
test_broadcast_on_empty_tensor
(
is_trace
):
input1_shape
=
(
100
,
0
,
1
)
output1_shape
=
(
100
,
0
,
10
)
data1
=
t
ensor
(
np
.
random
.
random
(
input1_shape
).
astype
(
np
.
float32
))
data1
=
T
ensor
(
np
.
random
.
random
(
input1_shape
).
astype
(
np
.
float32
))
input2_shape
=
(
10
,
0
)
output2_shape
=
(
10
,
10
,
0
)
data2
=
t
ensor
(
np
.
random
.
random
(
input2_shape
).
astype
(
np
.
float32
))
data2
=
T
ensor
(
np
.
random
.
random
(
input2_shape
).
astype
(
np
.
float32
))
input3_shape
=
(
0
,
0
,
1
,
10
)
output3_shape
=
(
10
,
0
,
0
,
10
,
10
)
data3
=
t
ensor
(
np
.
random
.
random
(
input3_shape
).
astype
(
np
.
float32
))
data3
=
T
ensor
(
np
.
random
.
random
(
input3_shape
).
astype
(
np
.
float32
))
def
comp
(
out
,
target_shp
):
assert
out
.
_tuple_shape
==
target_shp
...
...
@@ -705,7 +804,7 @@ def test_utils_astensor1d(is_varnode):
def
test_device
():
x
=
t
ensor
([
1
,
2
,
3
],
dtype
=
"float32"
)
x
=
T
ensor
([
1
,
2
,
3
],
dtype
=
"float32"
)
y1
=
F
.
eye
(
x
.
shape
,
dtype
=
"float32"
)
y2
=
F
.
eye
(
x
.
shape
,
dtype
=
"float32"
,
device
=
None
)
...
...
@@ -789,7 +888,7 @@ def test_copy_d2d(is_varnode):
)
@
pytest
.
mark
.
parametrize
(
"is_symbolic"
,
[
None
,
True
,
False
])
def
test_copy_empty
(
shape
,
device_src
,
device_dst
,
is_symbolic
):
inp
=
t
ensor
(
np
.
random
.
randn
(
*
shape
).
astype
(
"float32"
),
device
=
device_src
)
inp
=
T
ensor
(
np
.
random
.
randn
(
*
shape
).
astype
(
"float32"
),
device
=
device_src
)
def
func
(
inp
):
return
F
.
copy
(
inp
,
device_dst
)
...
...
@@ -885,6 +984,12 @@ def test_roll(shape, shifts, axis, is_varnode):
else
:
network
=
None
x
=
Tensor
([[
1
,
2
],
[
3
,
4
],
[
5
,
6
]],
np
.
int32
)
y
=
F
.
roll
(
x
,
1
,
-
1
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
np
.
array
([[
2
,
1
],
[
4
,
3
],
[
6
,
5
]]).
astype
(
np
.
int32
)
)
inp
=
np
.
random
.
randn
(
*
shape
).
astype
(
"float32"
)
def
func
(
inp
):
...
...
@@ -904,7 +1009,7 @@ def test_roll(shape, shifts, axis, is_varnode):
)
@
pytest
.
mark
.
parametrize
(
"is_symbolic"
,
[
None
,
True
,
False
])
def
test_roll_empty_tensor
(
shape
,
shifts
,
axis
,
is_symbolic
):
inp
=
t
ensor
(
np
.
random
.
randn
(
*
shape
).
astype
(
"float32"
))
inp
=
T
ensor
(
np
.
random
.
randn
(
*
shape
).
astype
(
"float32"
))
def
func
(
inp
):
return
F
.
roll
(
inp
,
shifts
,
axis
)
...
...
imperative/src/impl/ops/indexing.cpp
浏览文件 @
24c5c19b
#include "../dnn_op_helper.h"
#include "megbrain/imperative/ops/autogen.h"
#include "../op_trait.h"
#include "megbrain/opr/indexing.h"
#include "megdnn/oprs/general.h"
namespace
mgb
{
namespace
imperative
{
...
...
@@ -12,10 +14,8 @@ namespace indexing_one_hot {
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_output_attrs_fallible
(
const
OpDef
&
def
,
const
SmallVector
<
LogicalTensorDesc
>&
input_descs
)
{
auto
&
op
=
def
.
cast_final_safe
<
IndexingOneHot
>
();
auto
&&
op
=
def
.
cast_final_safe
<
IndexingOneHot
>
();
mgb_assert
(
input_descs
.
size
()
==
2
,
"IndexingOneHot expects two inputs"
);
auto
comp_node
=
input_descs
[
0
].
comp_node
;
TensorLayout
src
=
input_descs
[
0
].
layout
,
index
=
input_descs
[
1
].
layout
;
...
...
@@ -28,10 +28,15 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
mgb_assert
(
src
.
ndim
>=
2
,
"src ndim must be at least 2"
);
mgb_assert
(
src
.
is_contiguous
(),
"src should be contiguous"
);
mgb_assert
(
op
.
axis
>=
0
&&
op
.
axis
<
src
.
ndim
,
"axis %d not exists in src"
,
op
.
axis
);
-
static_cast
<
int
>
(
src
.
ndim
)
<=
op
.
axis
&&
op
.
axis
<
static_cast
<
int
>
(
src
.
ndim
),
"axis %d not exists in src"
,
op
.
axis
);
int
real_axis
=
static_cast
<
int
>
(
op
.
axis
);
if
(
real_axis
<
0
)
{
real_axis
+=
static_cast
<
int
>
(
src
.
ndim
);
}
TensorLayout
dst
=
src
;
dst
.
shape
[
op
.
axis
]
=
1
;
dst
.
shape
[
real_
axis
]
=
1
;
dst
.
init_contiguous_stride
();
if
(
!
index
.
ndim
)
{
...
...
@@ -40,24 +45,128 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
mgb_assert
(
index
.
is_contiguous
(),
"index should be all contiguous"
);
mgb_assert
(
index
.
eq_shape
(
src
.
remove_axis
(
op
.
axis
)),
"index shape doesn't match src"
);
index
.
eq_shape
(
src
.
remove_axis
(
real_axis
)),
"index shape doesn't match src"
);
return
{{{
dst
,
comp_node
}},
true
};
}
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
op
=
static_cast
<
const
IndexingOneHot
&>
(
def
);
auto
&&
op
=
def
.
cast_final_safe
<
IndexingOneHot
>
(
);
mgb_assert
(
inputs
.
size
()
==
2
);
int
real_axis
=
static_cast
<
int
>
(
op
.
axis
);
if
(
real_axis
<
0
)
{
real_axis
+=
static_cast
<
int
>
(
op
.
ndim
);
}
OperatorNodeConfig
config
{
op
.
make_name
()};
return
opr
::
IndexingOneHot
::
make
(
inputs
[
0
],
inputs
[
1
],
op
.
param
(),
config
);
return
opr
::
IndexingOneHot
::
make
(
inputs
[
0
],
inputs
[
1
],
real_axis
,
config
);
}
SmallVector
<
TensorPtr
>
apply_on_physical_tensor
(
const
OpDef
&
def
,
SmallVector
<
TensorPtr
>
inputs
,
SmallVector
<
LogicalTensorDesc
>&
output_descs
,
const
bool
&
validated
)
{
auto
&&
op
=
def
.
cast_final_safe
<
IndexingOneHot
>
();
auto
&&
inp
=
inputs
[
0
];
auto
&&
index
=
inputs
[
1
];
TensorLayout
layout
=
inp
->
layout
();
TensorLayout
index_layout
=
index
->
layout
();
DnnOprCaller
<
megdnn
::
IndexingOneHot
>
dnn_op
(
inp
->
comp_node
());
auto
&&
indexing_one_hot_param
=
dnn_op
.
op
->
param
();
int
real_axis
=
static_cast
<
int
>
(
op
.
axis
);
if
(
real_axis
<
0
)
{
real_axis
+=
static_cast
<
int
>
(
layout
.
ndim
);
}
mgb_assert
(
0
<=
real_axis
&&
real_axis
<
static_cast
<
int
>
(
layout
.
ndim
),
"Dimension out of range (expected to be in range of [%d, %d], but got %d)"
,
0
,
static_cast
<
int
>
(
layout
.
ndim
)
-
1
,
op
.
axis
);
indexing_one_hot_param
=
real_axis
;
TensorLayout
tlayout
;
dnn_op
.
op
->
deduce_layout
(
layout
,
index_layout
,
tlayout
);
TensorPtr
out
=
Tensor
::
make
(
tlayout
,
inp
->
comp_node
());
megdnn
::
TensorND
in
=
inp
->
dnn_tensor
();
megdnn
::
TensorND
ind
=
index
->
dnn_tensor
();
TensorLayout
m_layout
(
{
dnn_op
.
op
->
get_workspace_in_bytes
(
layout
,
index_layout
,
tlayout
)},
dtype
::
Byte
());
auto
dnn_workspace
=
dnn_op
.
create_workspace
(
m_layout
);
dnn_op
.
op
->
exec
(
in
,
ind
,
out
->
dnn_tensor
(),
dnn_workspace
);
return
{
out
};
}
OP_TRAIT_REG
(
IndexingOneHot
,
IndexingOneHot
)
.
infer_output_attrs_fallible
(
infer_output_attrs_fallible
)
.
apply_on_var_node
(
apply_on_var_node
)
.
apply_on_physical_tensor
(
apply_on_physical_tensor
)
.
fallback
();
}
// namespace indexing_one_hot
namespace
indexing_set_one_hot
{
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_output_attrs_fallible
(
const
OpDef
&
def
,
const
SmallVector
<
LogicalTensorDesc
>&
input_descs
)
{
mgb_assert
(
input_descs
.
size
()
==
3
,
"IndexingSetOneHot expects three inputs"
);
auto
comp_node
=
input_descs
[
0
].
comp_node
;
TensorLayout
src
=
input_descs
[
0
].
layout
,
index
=
input_descs
[
1
].
layout
;
mgb_assert
(
index
.
dtype
==
dtype
::
Int32
(),
"index dtype must be int32"
);
if
(
!
src
.
ndim
)
{
return
{{{{{},
src
.
dtype
},
comp_node
}},
false
};
}
mgb_assert
(
src
.
is_contiguous
(),
"src should be contiguous"
);
return
{{
input_descs
[
0
]},
true
};
}
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
op
=
static_cast
<
const
IndexingSetOneHot
&>
(
def
);
mgb_assert
(
inputs
.
size
()
==
3
);
int
real_axis
=
static_cast
<
int
>
(
op
.
axis
);
if
(
real_axis
<
0
)
{
real_axis
+=
static_cast
<
int
>
(
op
.
ndim
);
}
OperatorNodeConfig
config
{
op
.
make_name
()};
return
opr
::
IndexingSetOneHot
::
make
(
inputs
[
0
],
inputs
[
1
],
inputs
[
2
],
real_axis
,
config
);
}
SmallVector
<
TensorPtr
>
apply_on_physical_tensor
(
const
OpDef
&
def
,
SmallVector
<
TensorPtr
>
inputs
,
SmallVector
<
LogicalTensorDesc
>&
output_descs
,
const
bool
&
validated
)
{
auto
&&
op
=
def
.
cast_final_safe
<
IndexingSetOneHot
>
();
auto
&&
inp
=
inputs
[
0
];
auto
&&
index
=
inputs
[
1
];
auto
&&
sub
=
inputs
[
2
];
TensorLayout
layout
=
inp
->
layout
();
TensorLayout
index_layout
=
index
->
layout
();
TensorLayout
tlayout
=
sub
->
layout
();
mgb_assert
(
layout
.
is_contiguous
());
DnnOprCaller
<
megdnn
::
IndexingSetOneHot
>
dnn_op
(
inp
->
comp_node
());
auto
&&
indexing_one_hot_param
=
dnn_op
.
op
->
param
();
int
real_axis
=
static_cast
<
int
>
(
op
.
axis
);
if
(
real_axis
<
0
)
{
real_axis
+=
static_cast
<
int
>
(
layout
.
ndim
);
}
indexing_one_hot_param
=
real_axis
;
TensorPtr
out
=
Tensor
::
make
(
layout
,
inp
->
comp_node
());
out
->
dev_tensor
().
copy_from_fixlayout
(
inp
->
dev_tensor
());
megdnn
::
TensorND
in
=
inp
->
dnn_tensor
();
megdnn
::
TensorND
ind
=
index
->
dnn_tensor
();
megdnn
::
TensorND
su
=
sub
->
dnn_tensor
();
TensorLayout
m_layout
(
{
dnn_op
.
op
->
get_workspace_in_bytes
(
layout
,
index_layout
,
tlayout
)},
dtype
::
Byte
());
auto
dnn_workspace
=
dnn_op
.
create_workspace
(
m_layout
);
dnn_op
.
op
->
exec
(
out
->
dnn_tensor
(),
ind
,
su
,
dnn_workspace
);
return
{
out
};
}
OP_TRAIT_REG
(
IndexingSetOneHot
,
IndexingSetOneHot
)
.
infer_output_attrs_fallible
(
infer_output_attrs_fallible
)
.
apply_on_var_node
(
apply_on_var_node
)
.
apply_on_physical_tensor
(
apply_on_physical_tensor
)
.
fallback
();
}
// namespace indexing_set_one_hot
}
// anonymous namespace
}
// namespace imperative
}
// namespace mgb
...
...
imperative/src/impl/ops/specializations.cpp
浏览文件 @
24c5c19b
...
...
@@ -372,21 +372,6 @@ OP_TRAIT_REG(GroupLocal, GroupLocal).apply_on_var_node(apply_on_var_node).fallba
}
// namespace group_local
}
// namespace
namespace
{
namespace
indexing_set_one_hot
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
op
=
static_cast
<
const
IndexingSetOneHot
&>
(
def
);
mgb_assert
(
inputs
.
size
()
==
3
);
OperatorNodeConfig
config
{
op
.
make_name
()};
return
opr
::
IndexingSetOneHot
::
make
(
inputs
[
0
],
inputs
[
1
],
inputs
[
2
],
op
.
param
(),
config
);
}
OP_TRAIT_REG
(
IndexingSetOneHot
,
IndexingSetOneHot
)
.
apply_on_var_node
(
apply_on_var_node
)
.
fallback
();
}
// namespace indexing_set_one_hot
}
// namespace
namespace
{
namespace
typecvt
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
...
...
src/core/include/megbrain/ir/ops.td
浏览文件 @
24c5c19b
...
...
@@ -108,9 +108,17 @@ def Remap: MgbHashableOp<"Remap", [RemapParam]>;
def Resize: MgbHashableOp<"Resize", [ResizeParam]>;
def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]>;
def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]> {
let extraArguments = (ins
MgbI32Attr:$ndim
);
}
def IndexingSetOneHot: MgbHashableOp<"IndexingSetOneHot", [AxisParam]>;
def IndexingSetOneHot: MgbHashableOp<"IndexingSetOneHot", [AxisParam]> {
let extraArguments = (ins
MgbI32Attr:$ndim
);
}
def Copy: MgbHashableOp<"Copy"> {
let extraArguments = (ins
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录